{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V6iqZHPjN2S0"
      },
      "source": [
        "# License\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at:\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SNwvSD5wN7Er"
      },
      "source": [
        "# Instructions\n",
        "\n",
        "This notebook allows you to reproduce the experiments reported in the publication titled:\n",
        "\n",
        "\"[*Multipath Agents for Modular Multitask ML Systems*](https://arxiv.org/abs/2302.02721)\" (2023)\n",
        "\n",
        "---\n",
        "To start an experiment:\n",
        "---\n",
        "1. Choose the agent type by setting the `AGENT` variable in the configuration below.\n",
        "Select `Vit` to choose the singlepath agent. Select `MultiVit` to choose the multipath agent.\n",
        "These agents use a ViT-Large root model.\n",
        "Select agent types suffixed with `T3` to use a ViT-Tiny root model capped at 3 layers, or agent types suffixed with `B` to use a ViT-Base root model.\n",
        "\n",
        "1. Set `TASK_NAME` to the string-id of the task assigned to the instantiation of the selected agent.\n",
        "Refer to `TFDS_IMAGE_CLASSIFCATON_DATASETS` and `VTAB_TASKS` (below) for the lists of task ids that have been tested with the current code.\n",
        "These lists contain task ids from the [Tensorflow Datasets Catalog](https://www.tensorflow.org/datasets/catalog/overview).\n",
        "Note that some tasks require manual download; refer to the corresponding catalog page for instructions. **WARNING**: The system state needs to be populated with at least one root model before running an agent training on any task. To generate the root model, set `TASK_NAME` to `\"root_model/checkpoint\"` to load a pretrained root model, or to `\"root_model/random_init\"` to generate a randomly initialized one.\n",
        "\n",
        "1. Set `NUM_CYCLES_MAX` to the desired number of evolutionary cycles. Additional configuration parameters can be modified in the Agents definitions code below. Configurations are set to the settings described in the publication.\n",
        "\n",
        "1. Set `SYSTEM_STATE_RELATIVE_DIR` to a relative path from where the system state will be read and written.\n",
        "\n",
        "1. By default, the system state is stored under a temporary folder on the Virtual Machine (VM) local storage. This temporary folder is deleted when the VM is stopped or restarted.\n",
        "It is possible to store the system state folder on your Google Drive by activating the\n",
        "`SAVE_SYSTEM_STATE_ON_GOOGLE_DRIVE` option. In this case, you will be prompted for access approval and the system state folder will be saved in a folder named `\"munet_experiments\"` under your Google Drive root folder. Furthermore, it is also possible to store the system state into a Google Drive folder shared with multiple users by creating a link to the shared folder into your Google Drive and then setting `GDRIVE_ROOT_DIR` (below) to the path of the linked shared folder. \n",
        "\n",
        "1. To start the experiment, select \"Connect to a hosted runtime\" from the dropdown menu on the top right, and then select \"Run all\" from the \"Runtime\" menu. A free hosted CPU runtime is always available. Free access to GPU and TPU accelerators is occasionally provided by the service depending on availability.\n",
        "\n",
        "---\n",
        "During the experiment execution:\n",
        "---\n",
        "\n",
        "1. The print output is displayed after the last cell of this Colab.\n",
        "\n",
        "1. The system state folder is populated with a subfolder for each agent.\n",
        "The name of each agent folder is prefixed with the `AGENT` type string and suffixed with the `TASK_NAME`.\n",
        "Each agent directory is populated with incremental state subfolders containing the sharded state of the architectures and parameters generated by the agent during the corresponding evolutionary cycle.\n",
        "\n",
        "1. Agents can be started asynchronously and run in parallel in varying quantities.\n",
        "It is possible to resume an interrupted agent training by restarting the execution with the same configuration.\n",
        "It is possible to continue a completed training by increasing `NUM_CYCLES_MAX`.\n",
        "\n",
        "1. To achieve a multi-agent execution, multiple Colabs need to be run in parallel, each set to the same configuration but with a different `TASK_NAME`.\n",
        "\n",
        "1. To achieve heterogeneous hardware execution, parallel Colab Notebooks can be connected to runtimes of different types.\n",
        "It is possible to switch between CPU, GPU and TPU by selecting `Change runtime type` in the `Resources` tab in this Colab Notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M93tll7z29rX"
      },
      "outputs": [],
      "source": [
        "# @title Agent parameters\n",
        "AGENT = \"VitT3\" # @param [\"VitT3\", \"VitB\", \"Vit\", \"MultiVitT3\", \"MultiVitB\", \"MultiVit\"] { type: \"string\", isTemplate: true }\n",
        "# Set TASK_NAME to \"root_model/checkpoint\" or \"root_model/random_init\" to initialize the population.\n",
        "TASK_NAME = \"root_model/checkpoint\"  # @param { type: \"string\", isTemplate: true }\n",
        "NUM_CYCLES_MAX = 1 # @param { type: \"integer\", isTemplate: true }\n",
        "SYSTEM_STATE_RELATIVE_DIR = \"munet_system_state/\"  # @param { type: \"string\", isTemplate: true }"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xcaI2KII_MH5"
      },
      "outputs": [],
      "source": [
        "# Saves the system state on Google Drive instead of saving it in a temporary VM folder.\n",
        "SAVE_SYSTEM_STATE_ON_GOOGLE_DRIVE = False  # @param { type: \"boolean\", isTemplate: true }\n",
        "if SAVE_SYSTEM_STATE_ON_GOOGLE_DRIVE:\n",
        "  from google.colab import drive\n",
        "  drive.mount('/content/gdrive')\n",
        "  GDRIVE_ROOT_DIR = \"/content/gdrive/My Drive/munet_experiments/\"\n",
        "  SYSTEM_STATE_DIR = GDRIVE_ROOT_DIR + SYSTEM_STATE_RELATIVE_DIR\n",
        "  print(\"Saving system state in Google Drive.\")\n",
        "else:\n",
        "  SYSTEM_STATE_DIR = \"/tmp/\" + SYSTEM_STATE_RELATIVE_DIR\n",
        "  print(\"WARNING: Saving system state in VM, state will be lost after reboot!\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dsiDs_mgBZOx"
      },
      "outputs": [],
      "source": [
        "# Test immutability of published paths at the beginning of each cycle.\n",
        "# Tolerance may be increased if the system is run with any context difference, e.g. hardware, input preprocessing, library or dataset versions.\n",
        "TEST_IMMUTABILITY = False\n",
        "IMMUTABILITY_RELATIVE_TOLLERANCE = 0.001  # 0.1%"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KZ0njfC-XCBA"
      },
      "source": [
        "# Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "G9CQIJwcYN2N"
      },
      "outputs": [],
      "source": [
        "!pip install --upgrade -q pip jax jaxlib\n",
        "!pip install --upgrade -q git+https://github.com/google/flax.git\n",
        "!pip install -q ml_collections\n",
        "!pip install -q tensorflow_addons\n",
        "![ -d task_adaptation ] || git clone --depth=1 https://github.com/google-research/task_adaptation\n",
        "![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer\n",
        "\n",
        "import sys\n",
        "if './task_adaptation' not in sys.path:\n",
        "  sys.path.append('./task_adaptation')\n",
        "if './vision_transformer' not in sys.path:\n",
        "  sys.path.append('./vision_transformer')\n",
        "\n",
        "import jax.tools.colab_tpu\n",
        "try:\n",
        "  jax.tools.colab_tpu.setup_tpu()\n",
        "except Exception:\n",
        "  pass  # Not a TPU runtime."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LzhEweKzMg6k"
      },
      "outputs": [],
      "source": [
        "import copy\n",
        "import flax\n",
        "import flax.linen as nn\n",
        "import gc\n",
        "import inspect\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import json\n",
        "import math\n",
        "import numpy as np\n",
        "import optax\n",
        "import os\n",
        "import pandas as pd\n",
        "import random\n",
        "import re\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "import time\n",
        "from collections import defaultdict\n",
        "from functools import partial\n",
        "from flax.training import checkpoints as flax_checkpoints\n",
        "from ml_collections import ConfigDict, FrozenConfigDict\n",
        "from tensorflow.io import gfile\n",
        "from threading import Thread, Lock\n",
        "from typing import Any, Type\n",
        "\n",
        "tf.compat.v1.enable_eager_execution()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6QHBzcuUYeh5"
      },
      "outputs": [],
      "source": [
        "# ViT imports.\n",
        "from vision_transformer.vit_jax import input_pipeline\n",
        "from vision_transformer.vit_jax import checkpoint\n",
        "from vision_transformer.vit_jax.configs import models as models_config  # Model configurations.\n",
        "from vision_transformer.vit_jax import models_vit as models  # Actual model code.\n",
        "# VTAB imports.\n",
        "import task_adaptation.registry as task_adapt_registry\n",
        "import task_adaptation.data.caltech\n",
        "import task_adaptation.data.cifar\n",
        "import task_adaptation.data.dtd\n",
        "import task_adaptation.data.oxford_flowers102\n",
        "import task_adaptation.data.oxford_iiit_pet\n",
        "import task_adaptation.data.sun397\n",
        "import task_adaptation.data.svhn\n",
        "import task_adaptation.data.patch_camelyon\n",
        "import task_adaptation.data.eurosat\n",
        "import task_adaptation.data.resisc45\n",
        "import task_adaptation.data.diabetic_retinopathy\n",
        "import task_adaptation.data.clevr\n",
        "import task_adaptation.data.dmlab\n",
        "import task_adaptation.data.dsprites\n",
        "import task_adaptation.data.kitti\n",
        "import task_adaptation.data.smallnorb"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Zhl1L8ldXFJz"
      },
      "source": [
        "# Utils"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SHBWj0JmpWDX"
      },
      "outputs": [],
      "source": [
        "# Ref. Tfds catalog: https://www.tensorflow.org/datasets/catalog/overview\n",
        "TFDS_IMAGE_CLASSIFCATON_DATASETS = set([\n",
        "    \"beans\",\n",
        "    \"binary_alpha_digits\",\n",
        "    \"caltech_birds2010\",\n",
        "    \"caltech_birds2011\",\n",
        "    \"cars196\",\n",
        "    \"cassava\",\n",
        "    \"cats_vs_dogs\",\n",
        "    \"cifar10\",\n",
        "    \"cifar100\",\n",
        "    \"citrus_leaves\",\n",
        "    \"cmaterdb/bangla\",\n",
        "    \"cmaterdb/devanagari\",\n",
        "    \"cmaterdb/telugu\",\n",
        "    \"colorectal_histology\",\n",
        "    \"controlled_noisy_web_labels/mini_imagenet_red\",\n",
        "    \"controlled_noisy_web_labels/mini_imagenet_blue\",\n",
        "    \"curated_breast_imaging_ddsm/patches\",\n",
        "    \"cycle_gan/apple2orange\",\n",
        "    \"cycle_gan/summer2winter_yosemite\",\n",
        "    \"cycle_gan/horse2zebra\",\n",
        "    \"cycle_gan/monet2photo\",\n",
        "    \"cycle_gan/cezanne2photo\",\n",
        "    \"cycle_gan/ukiyoe2photo\",\n",
        "    \"cycle_gan/vangogh2photo\",\n",
        "    \"cycle_gan/maps\",\n",
        "    \"cycle_gan/cityscapes\",\n",
        "    \"cycle_gan/facades\",\n",
        "    \"cycle_gan/iphone2dslr_flower\",\n",
        "    \"deep_weeds\",\n",
        "    \"domainnet/real\",\n",
        "    \"domainnet/painting\",\n",
        "    \"domainnet/clipart\",\n",
        "    \"domainnet/quickdraw\",\n",
        "    \"domainnet/infograph\",\n",
        "    \"domainnet/sketch\",\n",
        "    \"emnist/balanced\",\n",
        "    \"emnist/byclass\",\n",
        "    \"emnist/bymerge\",\n",
        "    \"emnist/digits\",\n",
        "    \"emnist/letters\",\n",
        "    \"emnist/mnist\",\n",
        "    \"fashion_mnist\",\n",
        "    \"food101\",\n",
        "    \"horses_or_humans\",\n",
        "    \"i_naturalist2017\",\n",
        "    \"i_naturalist2018\",\n",
        "    \"imagenet2012\",\n",
        "    \"imagenet_a\",\n",
        "    \"imagenet_lt\",\n",
        "    \"imagenet_r\",\n",
        "    \"imagenet_sketch\",\n",
        "    \"imagenette\",\n",
        "    \"imagewang\",\n",
        "    \"kmnist\",\n",
        "    \"malaria\",\n",
        "    \"mnist\",\n",
        "    \"omniglot\",\n",
        "    \"pet_finder\",\n",
        "    \"places365_small\",\n",
        "    \"plant_village\",\n",
        "    \"plantae_k\",\n",
        "    \"quickdraw_bitmap\",\n",
        "    \"rock_paper_scissors\",\n",
        "    \"siscore/rotation\",\n",
        "    \"siscore/size\",\n",
        "    \"siscore/location\",\n",
        "    \"stanford_dogs\",\n",
        "    \"stanford_online_products\",\n",
        "    \"stl10\",\n",
        "    \"tf_flowers\",\n",
        "    \"uc_merced\",\n",
        "    \"visual_domain_decathlon/aircraft\",\n",
        "    \"visual_domain_decathlon/cifar100\",\n",
        "    \"visual_domain_decathlon/daimlerpedcls\",\n",
        "    \"visual_domain_decathlon/dtd\",\n",
        "    \"visual_domain_decathlon/gtsrb\",\n",
        "    \"visual_domain_decathlon/imagenet12\",\n",
        "    \"visual_domain_decathlon/omniglot\",\n",
        "    \"visual_domain_decathlon/svhn\",\n",
        "    \"visual_domain_decathlon/ucf101\",\n",
        "    \"visual_domain_decathlon/vgg-flowers\",\n",
        "    ])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XI7xcfWLXx-G"
      },
      "outputs": [],
      "source": [
        "# Append suffix \"/1k\" to get the 1k version of each task.\n",
        "VTAB_TASKS = [\n",
        "    \"caltech101\",\n",
        "    # cifar100 and cifar10 were already added above with a slightly different val split but the same test set, so only the 1k versions are added here.\n",
        "    \"cifar100/1k\",\n",
        "    \"cifar10/1k\",\n",
        "    \"dtd\",\n",
        "    \"oxford_flowers102\",\n",
        "    \"oxford_iiit_pet\",\n",
        "    \"sun397\",\n",
        "    \"svhn_cropped\",\n",
        "    \"patch_camelyon\",\n",
        "    \"eurosat\",\n",
        "    \"resisc45\",\n",
        "    \"diabetic_retinopathy_detection/btgraham-300\",\n",
        "    \"clevr/count_cylinders\",  # Not in results table.\n",
        "    \"clevr/count_all\",  # Clevr-Count\n",
        "    \"clevr/closest_object_distance\",  # Clevr-Dist\n",
        "    \"dmlab\",\n",
        "    \"dsprites/label_x_position\",  # dSpr-Loc\n",
        "    \"dsprites/label_orientation\",  # dSpr-Ori\n",
        "    \"kitti/closest_object_distance\",  # Not in results table.\n",
        "    \"kitti/count_vehicles\",  # Not in results table.\n",
        "    \"kitti/closest_vehicle_distance\",  # Kitti-dist\n",
        "    \"smallnorb/label_category\",  # Not in results table.\n",
        "    \"smallnorb/label_lighting\",  # Not in results table.\n",
        "    \"smallnorb/label_azimuth\",  # Azim\n",
        "    \"smallnorb/label_elevation\",  # Elev\n",
        "    ]\n",
        "for tn in VTAB_TASKS:\n",
        "  assert tn not in TFDS_IMAGE_CLASSIFCATON_DATASETS, tn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OeLYCqhfXZyk"
      },
      "outputs": [],
      "source": [
        "def compute_flops_hlo(flax_module, *a, **kw):\n",
        "  # Compute flops on cpu for cross platform consistency.\n",
        "  analysis = jax.jit(flax_module, backend='cpu').lower(*a, **kw).cost_analysis()\n",
        "  return analysis[\"flops\"]"
      ]
    },
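    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Usage sketch (illustrative, not part of the original experiments):\n",
        "# count the flops of a small matmul with compute_flops_hlo. Note that\n",
        "# `cost_analysis` availability can vary across jax versions.\n",
        "_flops = compute_flops_hlo(lambda x: x @ x, jnp.ones((8, 8)))\n",
        "print(f\"flops for 8x8 matmul: {_flops}\")"
      ]
    },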
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IBFJAJe6XmH7"
      },
      "outputs": [],
      "source": [
        "class ObjectCache():\n",
        "  def __init__(self, factory_fn):\n",
        "    self.factory_fn = factory_fn\n",
        "    self.factory_fn_signature = inspect.signature(factory_fn)\n",
        "    self.cache = {}\n",
        "\n",
        "  def __call__(self, *args, **kwargs):\n",
        "    assert not args, \"No positional arguments allowed.\"\n",
        "    kw_params = {}\n",
        "    fn_name = self.factory_fn.__name__\n",
        "    fn_params = inspect.signature(self.factory_fn).parameters\n",
        "    for k_param, v_param in fn_params.items():\n",
        "      if k_param in kwargs:\n",
        "        kw_params[k_param] = kwargs[k_param]\n",
        "      elif v_param.default != v_param.empty:\n",
        "        # Fallback to declared default value.\n",
        "        kw_params[k_param] = fn_params[k_param].default\n",
        "      else:\n",
        "        assert False, (\n",
        "            f\"Missing value for argument {k_param} for function {fn_name}\")\n",
        "\n",
        "      if v_param.annotation != v_param.empty:\n",
        "        # Apply annotated type.\n",
        "        assert isinstance(v_param.annotation, type)\n",
        "        kw_params[k_param] = v_param.annotation(kw_params[k_param])\n",
        "\n",
        "    key = json.dumps(kw_params, sort_keys=True)\n",
        "    if key not in self.cache:\n",
        "      self.cache[key] = self.factory_fn(**kw_params)\n",
        "      print(f\"Added to cache: {fn_name}({key})  [cache size {len(self.cache)}]\")\n",
        "    return self.cache[key]"
      ]
    },
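    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Usage sketch (illustrative, not part of the original experiments):\n",
        "# ObjectCache memoizes a factory on its resolved keyword arguments, so\n",
        "# repeated calls with the same kwargs return the same cached object.\n",
        "def _demo_factory(name: str = \"demo\", size: int = 2):\n",
        "  return [name] * size\n",
        "\n",
        "_demo_cache = ObjectCache(_demo_factory)\n",
        "assert _demo_cache(size=3) is _demo_cache(size=3)  # Second call hits the cache.\n",
        "assert _demo_cache(size=3) == [\"demo\", \"demo\", \"demo\"]"
      ]
    },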
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AgTEwq24TYn8"
      },
      "source": [
        "# Models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ueAYFssCUG78"
      },
      "outputs": [],
      "source": [
        "# Sample inputs\n",
        "def get_sample_images(image_size, batch_size):\n",
        "  return np.zeros((batch_size, image_size, image_size, 3))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tkv4dVcQV4YB"
      },
      "outputs": [],
      "source": [
        "def get_num_params(params):\n",
        "  return sum(jax.tree_util.tree_flatten(\n",
        "      jax.tree_util.tree_map(lambda p: np.prod(p.shape), params)\n",
        "      )[0])"
      ]
    },
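    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Usage sketch (illustrative, not part of the original experiments):\n",
        "# get_num_params counts the parameters in any pytree of arrays.\n",
        "_demo_params = {\"kernel\": np.zeros((3, 4)), \"bias\": np.zeros((4,))}\n",
        "assert get_num_params(_demo_params) == 16  # 3*4 + 4"
      ]
    },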
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V6vIOQBEVWch"
      },
      "outputs": [],
      "source": [
        "def get_optimizer(\n",
        "    opt_lr: float,\n",
        "    opt_lr_schedule: str,\n",
        "    opt_lr_warmup_ratio: float,\n",
        "    opt_momentum: float,\n",
        "    opt_nesterov: bool,\n",
        "    num_train_batches_between_validations: int,\n",
        "    num_validations_per_path_training: int,\n",
        "    ):\n",
        "  min_lr = opt_lr / 1000.0\n",
        "  if opt_lr_schedule == \"constant\":\n",
        "    # Divide by 2 so that average lr is the same as other types.\n",
        "    learning_rate = 0.5 * opt_lr\n",
        "  elif opt_lr_schedule == \"linear\":\n",
        "    train_steps = int(num_train_batches_between_validations * num_validations_per_path_training)\n",
        "    warmup_steps = int(opt_lr_warmup_ratio * train_steps)\n",
        "    schedules = [\n",
        "        optax.linear_schedule(\n",
        "            init_value=min_lr,\n",
        "            end_value=opt_lr,\n",
        "            transition_steps=warmup_steps),\n",
        "        optax.linear_schedule(\n",
        "            init_value=opt_lr,\n",
        "            end_value=min_lr,\n",
        "            transition_steps=train_steps-warmup_steps)]\n",
        "    learning_rate = optax.join_schedules(schedules, [warmup_steps])\n",
        "  elif opt_lr_schedule == \"cosine\":\n",
        "    train_steps = int(num_train_batches_between_validations\n",
        "                      * num_validations_per_path_training)\n",
        "    learning_rate = optax.warmup_cosine_decay_schedule(\n",
        "        init_value=min_lr,\n",
        "        peak_value=opt_lr,\n",
        "        warmup_steps=int(opt_lr_warmup_ratio * train_steps),\n",
        "        decay_steps=train_steps)\n",
        "  elif opt_lr_schedule == \"restarts\":\n",
        "    train_steps = num_train_batches_between_validations\n",
        "    repeats = num_validations_per_path_training\n",
        "    kwargs = dict(\n",
        "        init_value=min_lr,\n",
        "        peak_value=opt_lr,\n",
        "        warmup_steps=int(opt_lr_warmup_ratio * train_steps),\n",
        "        decay_steps=train_steps,\n",
        "    )\n",
        "    kwargs = [kwargs] * repeats\n",
        "    learning_rate = optax.sgdr_schedule(kwargs)\n",
        "  else:\n",
        "    assert False, f\"Invalid lr schedule: {opt_lr_schedule}\"\n",
        "\n",
        "  return optax.chain(\n",
        "      optax.clip_by_global_norm(1.0),\n",
        "      optax.sgd(\n",
        "          learning_rate=learning_rate,\n",
        "          momentum=opt_momentum,\n",
        "          nesterov=opt_nesterov,\n",
        "          accumulator_dtype=jnp.bfloat16))"
      ]
    },
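    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Usage sketch (illustrative, not part of the original experiments):\n",
        "# build the clipped-SGD optimizer with a cosine schedule and take a\n",
        "# single update step on a toy parameter tree.\n",
        "_tx = get_optimizer(\n",
        "    opt_lr=0.01, opt_lr_schedule=\"cosine\", opt_lr_warmup_ratio=0.1,\n",
        "    opt_momentum=0.9, opt_nesterov=True,\n",
        "    num_train_batches_between_validations=100,\n",
        "    num_validations_per_path_training=4)\n",
        "_toy_params = {\"w\": jnp.ones((2,))}\n",
        "_opt_state = _tx.init(_toy_params)\n",
        "_grads = {\"w\": jnp.ones((2,))}  # Stand-in gradients.\n",
        "_updates, _opt_state = _tx.update(_grads, _opt_state, _toy_params)\n",
        "_toy_params = optax.apply_updates(_toy_params, _updates)"
      ]
    },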
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2GpeOYxFi1Sm"
      },
      "outputs": [],
      "source": [
        "def merge_params(a, b):\n",
        "  params = a.copy(b)\n",
        "  assert len(params) == len(a) + len(b)\n",
        "  return params"
      ]
    },
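    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Usage sketch (illustrative, not part of the original experiments):\n",
        "# merge_params merges two FrozenDicts, asserting that their top-level\n",
        "# keys are disjoint.\n",
        "_pa = flax.core.freeze({\"head\": {\"w\": np.zeros((2,))}})\n",
        "_pb = flax.core.freeze({\"body\": {\"w\": np.zeros((2,))}})\n",
        "assert set(merge_params(_pa, _pb).keys()) == {\"head\", \"body\"}"
      ]
    },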
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ExYTVeamDcPy"
      },
      "source": [
        "## Vit Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "43JOxcGVMJ4s"
      },
      "outputs": [],
      "source": [
        "class VitModelFactory():\n",
        "  @staticmethod\n",
        "  def get_model(hparams, config):\n",
        "    return get_vit_model(hparams, config)\n",
        "\n",
        "  @staticmethod\n",
        "  def get_init_comps(hparams, config):\n",
        "    return get_vit_init_comps(hparams, config)\n",
        "\n",
        "  @staticmethod\n",
        "  def get_comps2model_fn():\n",
        "    return vit_comps2model\n",
        "\n",
        "  @staticmethod\n",
        "  def get_sample_input(hparams):\n",
        "    return get_sample_images(image_size=hparams[\"ds_image_size\"], batch_size=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vvZ_4-kJ9Pt3"
      },
      "outputs": [],
      "source": [
        "def get_vit_filename(query):\n",
        "  df = checkpoint.get_augreg_df()\n",
        "  res = df.query(query).filename.unique()\n",
        "  assert len(res) == 1\n",
        "  return res[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0lvqd47g9ZsW"
      },
      "outputs": [],
      "source": [
        "VIT_CONFIG_CACHE = {}\n",
        "\n",
        "def get_vit_config(query):\n",
        "  global VIT_CONFIG_CACHE\n",
        "  if query not in VIT_CONFIG_CACHE:\n",
        "    filename = get_vit_filename(query)\n",
        "    config = models_config.AUGREG_CONFIGS[filename.split(\"-\")[0]].copy_and_resolve_references()\n",
        "    config.unlock()\n",
        "    # Disable dropout.\n",
        "    config.transformer.dropout_rate = 0.0\n",
        "    config.transformer.attention_dropout_rate = 0.0\n",
        "    config.lock()\n",
        "    VIT_CONFIG_CACHE[query] = config\n",
        "  return VIT_CONFIG_CACHE[query].copy_and_resolve_references()\n",
        "\n",
        "def get_set_vit_config(hparams, config):\n",
        "  path_config = get_vit_config(config.vit_checkpoint_query)\n",
        "  path_config.transformer.num_layers = int(hparams[\"num_layers\"])\n",
        "  path_config.unlock()\n",
        "  path_config.num_classes = int(hparams[\"num_classes\"])\n",
        "  if \"classifier\" in hparams:\n",
        "    path_config.classifier = hparams[\"classifier\"]\n",
        "  path_config.lock()\n",
        "  path_config = FrozenConfigDict(path_config)\n",
        "  return path_config\n",
        "\n",
        "def get_max_num_layers(query):\n",
        "  config = get_vit_config(query)\n",
        "  return config.transformer.num_layers"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JOV3CPHMUQbG"
      },
      "outputs": [],
      "source": [
        "# Get params from ViT checkpoints.\n",
        "def get_vit_checkpoint_comps(image_size, query):\n",
        "  filename = get_vit_filename(query)\n",
        "  config = get_vit_config(query)\n",
        "  model = models.VisionTransformer(**config, num_classes=1)  # num_classes unused.\n",
        "  init_params = copy.deepcopy(jax.device_get(\n",
        "      model.init(jax.random.PRNGKey(random.randrange(int(1e10))),\n",
        "                 VitModelFactory.get_sample_input({\"ds_image_size\": image_size}),\n",
        "                 train=False  # Disables dropout, no effect on params.\n",
        "                 )[\"params\"]))\n",
        "  params = checkpoint.load_pretrained(\n",
        "    pretrained_path=f\"gs://vit_models/augreg/{filename}.npz\",\n",
        "    init_params=init_params,\n",
        "    model_config=config)\n",
        "  return vit_model2comps(params)\n",
        "\n",
        "def get_vit_checkpoint_reshaped_posembed_component(\n",
        "    agent_id: str, ds_image_size: int, query: str):\n",
        "  params = get_vit_checkpoint_comps(ds_image_size, query)[\"posembed_input\"]\n",
        "  return Component(name=\"posembed_input\",\n",
        "                   agent_id=agent_id,\n",
        "                   params=params,\n",
        "                   train_locks=[])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FjEWZgFIUt25"
      },
      "outputs": [],
      "source": [
        "# Get ViT model and init_params.\n",
        "def get_vit_model(hparams, config):\n",
        "  vit_config = get_set_vit_config(hparams, config)\n",
        "  return models.VisionTransformer(**vit_config)\n",
        "\n",
        "def get_vit_init_comps(hparams, config):\n",
        "  model = get_vit_model(hparams, config)\n",
        "  init_params = copy.deepcopy(jax.device_get(model.init(\n",
        "      jax.random.PRNGKey(random.randrange(int(1e10))),\n",
        "      VitModelFactory.get_sample_input(hparams),\n",
        "      train=False  # Disables dropout, no effect on params.\n",
        "      )[\"params\"]))\n",
        "  return vit_model2comps(init_params)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "txU4go5GUf-e"
      },
      "outputs": [],
      "source": [
        "# ViT parameters mapping to components.\n",
        "TRANSFORMER_KEYS = set(\n",
        "    [\"encoder_norm\", \"posembed_input\"] +\n",
        "    [f\"encoderblock_{k}\" for k in range(30)])\n",
        "\n",
        "def vit_model2comps(params):\n",
        "  new_params = {}\n",
        "  for k in params.keys():\n",
        "    if k == \"Transformer\":\n",
        "      t_params = params[k]\n",
        "      for t_k in t_params.keys():\n",
        "        new_params[t_k] = t_params[t_k]\n",
        "    else:\n",
        "      new_params[k] = params[k]\n",
        "  return flax.core.freeze(new_params)\n",
        "\n",
        "def vit_comps2model(params):\n",
        "  new_params = params.unfreeze()\n",
        "  new_params[\"Transformer\"] = {}\n",
        "  for k in list(new_params.keys()):\n",
        "    if k in TRANSFORMER_KEYS:\n",
        "      new_params[\"Transformer\"][k] = new_params.pop(k)\n",
        "  assert len(new_params[\"Transformer\"]) != 0\n",
        "  return flax.core.freeze(new_params)"
      ]
    },
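    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Usage sketch (illustrative, not part of the original experiments):\n",
        "# vit_model2comps lifts the \"Transformer\" sub-tree entries to top-level\n",
        "# components, and vit_comps2model restores the nesting.\n",
        "_vp = flax.core.freeze(\n",
        "    {\"Transformer\": {\"encoder_norm\": {\"scale\": np.ones((4,))}},\n",
        "     \"head\": {\"kernel\": np.zeros((4, 2))}})\n",
        "_comps = vit_model2comps(_vp)\n",
        "assert set(_comps.keys()) == {\"encoder_norm\", \"head\"}\n",
        "_restored = vit_comps2model(_comps)\n",
        "assert set(_restored[\"Transformer\"].keys()) == {\"encoder_norm\"}"
      ]
    },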
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ShsZ14HoiuFD"
      },
      "source": [
        "## MultiVit Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z0-D1trkMOor"
      },
      "outputs": [],
      "source": [
        "class MultiVitModelFactory():\n",
        "  @staticmethod\n",
        "  def get_model(hparams, config):\n",
        "    return get_multivit_model(hparams, config)\n",
        "\n",
        "  @staticmethod\n",
        "  def get_init_comps(hparams, config):\n",
        "    return get_multivit_init_comps(hparams, config)\n",
        "\n",
        "  @staticmethod\n",
        "  def get_comps2model_fn():\n",
        "    return multivit_comps2model\n",
        "\n",
        "  @staticmethod\n",
        "  def get_sample_input(hparams):\n",
        "    return {str(k): get_sample_images(image_size=k, batch_size=1) for k in hparams[\"ds_image_size\"]}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "61WJRSygqedz"
      },
      "outputs": [],
      "source": [
        "def get_multivit_init_comps(hparams, config):\n",
        "  model = get_multivit_model(hparams, config)\n",
        "  init_params = copy.deepcopy(jax.device_get(\n",
        "      model.init(\n",
        "          jax.random.PRNGKey(random.randrange(int(1e10))),\n",
        "          MultiVitModelFactory.get_sample_input(hparams),\n",
        "          train=False  # Disables dropout, no effect on params.\n",
        "          )[\"params\"]))\n",
        "  return multivit_model2comps(init_params)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YeB-D_HVdIDm"
      },
      "outputs": [],
      "source": [
        "def multivit_comps2model(params):\n",
        "  params = params.unfreeze()\n",
        "  for k in params:\n",
        "    if k.startswith(\"path_\"):\n",
        "      params[k] = vit_comps2model(flax.core.freeze(params[k]))\n",
        "  return flax.core.freeze(params)\n",
        "\n",
        "def multivit_model2comps(params):\n",
        "  # Mapping of path components is skipped since those are never used from random init.\n",
        "  return params"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "waN2hnq7kbOC"
      },
      "outputs": [],
      "source": [
        "class MultipathRouter(nn.Module):\n",
        "  init_main_path_weight: float\n",
        "  num_paths: int\n",
        "  lr_mult: float\n",
        "\n",
        "  @nn.compact\n",
        "  def __call__(self, x):\n",
        "    assert self.num_paths \u003e 0\n",
        "    assert self.lr_mult \u003e= 0 and self.lr_mult \u003c= 1\n",
        "    init_bias = np.log((1/(1/self.init_main_path_weight -1))*(self.num_paths-1))\n",
        "    x = nn.LayerNorm()(x)\n",
        "    x = nn.Dense(self.num_paths,\n",
        "                 kernel_init=nn.initializers.zeros,\n",
        "                 bias_init=nn.initializers.constant(np.asarray([init_bias]+[0]*(self.num_paths-1)))\n",
        "                 )(x)\n",
        "    x = nn.softmax(x)\n",
        "    x = self.lr_mult * x + (1-self.lr_mult) * jax.lax.stop_gradient(x)\n",
        "    return x\n",
        "\n",
        "class Connector(nn.Module):\n",
        "  out_dim: int\n",
        "\n",
        "  @nn.compact\n",
        "  def __call__(self, x):\n",
        "    x = nn.Dense(self.out_dim,\n",
        "                 kernel_init=nn.initializers.zeros,\n",
        "                 bias_init=nn.initializers.zeros,\n",
        "                 )(x)\n",
        "    return x\n",
        "\n",
        "class MultiVitModel(nn.Module):\n",
        "  config: Any\n",
        "  path_module: Type[nn.Module] = models.VisionTransformer\n",
        "  main_path_name: str = \"path_0\"\n",
        "  @nn.compact\n",
        "  def __call__(self, inputs, *, train):\n",
        "    logits_0 = self.path_module(\n",
        "        name=self.main_path_name,\n",
        "        **self.config.paths_configs[self.main_path_name])(\n",
        "            inputs[self.config.paths_image_size[self.main_path_name]], train=train)\n",
        "    out_dim = logits_0.shape[-1]\n",
        "    weights = MultipathRouter(\n",
        "        name=\"multipath_router\",\n",
        "        num_paths=len(self.config.paths_configs),\n",
        "        **self.config.router)(\n",
        "            logits_0)\n",
        "    all_logits = [logits_0]\n",
        "    for path_name in self.config.paths_configs:\n",
        "      if path_name == self.main_path_name:\n",
        "        continue\n",
        "      representation = self.path_module(\n",
        "          name=path_name,\n",
        "          **self.config.paths_configs[path_name])(\n",
        "              inputs[self.config.paths_image_size[path_name]], train=train)\n",
        "      path_logits = Connector(\n",
        "          name=f\"head_adapter_{path_name}\", out_dim=out_dim)(representation)\n",
        "      all_logits.append(path_logits)\n",
        "    stacked = jnp.stack(all_logits, axis=-1)\n",
        "    # Stop gradient on the stacked logits so this term trains only the router weights.\n",
        "    logits_comb = jnp.einsum(\"BLp,Bp-\u003eBL\", jax.lax.stop_gradient(stacked), weights)\n",
        "    logits_sum = jnp.einsum(\"BLp,Bp-\u003eBL\", stacked, jnp.ones_like(weights))\n",
        "    logits_out = logits_comb - jax.lax.stop_gradient(logits_sum) + logits_sum\n",
        "    return logits_out"
      ]
    },
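    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `init_bias` formula in `MultipathRouter` is chosen so that at initialization (the kernel is zero, so the router logits equal the bias) the softmax assigns `init_main_path_weight` to the main path and splits the remainder evenly across the other paths. A minimal stdlib-only sketch of this property (illustrative, not part of the training code):\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def init_router_weights(init_main_path_weight, num_paths):\n",
        "    # Same bias formula as MultipathRouter: only the main path gets a bias.\n",
        "    init_bias = math.log((1 / (1 / init_main_path_weight - 1)) * (num_paths - 1))\n",
        "    logits = [init_bias] + [0.0] * (num_paths - 1)\n",
        "    exps = [math.exp(l) for l in logits]\n",
        "    total = sum(exps)\n",
        "    return [e / total for e in exps]\n",
        "\n",
        "weights = init_router_weights(0.8, 3)\n",
        "# The main path receives 0.8; the remaining 0.2 splits evenly across the rest.\n",
        "```"
      ]
    },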
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CC4pF48eqDLL"
      },
      "outputs": [],
      "source": [
        "def get_multivit_model(hparams, config):\n",
        "  model_config = ConfigDict()\n",
        "  model_config.paths_configs = {\n",
        "      k: get_set_vit_config(hparams[\"paths\"][k][\"hparams\"], config) for k in hparams[\"paths\"]}\n",
        "  model_config.paths_image_size = {\n",
        "      k: str(hparams[\"paths\"][k][\"hparams\"][\"ds_image_size\"]) for k in hparams[\"paths\"]}\n",
        "  model_config.router = {\n",
        "      \"init_main_path_weight\": float(hparams[\"router_init_main_path_weight\"]),\n",
        "      \"lr_mult\": float(hparams[\"router_lr_mult\"]),\n",
        "  }\n",
        "  return MultiVitModel(config=FrozenConfigDict(model_config))"
      ]
    },
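    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `\"BLp,Bp-\u003eBL\"` einsum in `MultiVitModel` above combines the per-path logits as a per-example weighted sum over the path axis `p`. The same contraction written with plain Python loops on toy shapes (illustrative only):\n",
        "\n",
        "```python\n",
        "def combine_logits(stacked, weights):\n",
        "    # stacked: [B][L][p] per-path logits; weights: [B][p] router weights.\n",
        "    # Equivalent to jnp.einsum('BLp,Bp-\u003eBL', stacked, weights).\n",
        "    B, L, P = len(stacked), len(stacked[0]), len(weights[0])\n",
        "    return [[sum(stacked[b][l][p] * weights[b][p] for p in range(P))\n",
        "             for l in range(L)] for b in range(B)]\n",
        "\n",
        "stacked = [[[1.0, 3.0], [2.0, 4.0]]]  # B=1, L=2, p=2.\n",
        "weights = [[0.75, 0.25]]\n",
        "combined = combine_logits(stacked, weights)  # [[1.5, 2.5]]\n",
        "```"
      ]
    },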
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lyX6WjaUDWOi"
      },
      "source": [
        "# Agents"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FNAi2xK-fARf"
      },
      "outputs": [],
      "source": [
        "def format_agent_id(class_name, task_name):\n",
        "  agent_id = f\"{class_name}/{task_name}\"\n",
        "  assert \"~\" not in agent_id, f\"Invalid agent id: {agent_id}\"\n",
        "  return agent_id.replace(\"/\", \"~\")\n",
        "\n",
        "def get_agent_class(agent_id):\n",
        "  return globals()[agent_id.split(\"~\")[0]]"
      ]
    },
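    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Agent ids encode both the agent class and the task: the two are joined with `/` and every `/` is then replaced by `~`, so the class name can always be recovered as the text before the first `~`. A standalone sketch mirroring the helpers above (minus the `globals()` lookup):\n",
        "\n",
        "```python\n",
        "def format_agent_id(class_name, task_name):\n",
        "    agent_id = f'{class_name}/{task_name}'\n",
        "    assert '~' not in agent_id, f'Invalid agent id: {agent_id}'\n",
        "    return agent_id.replace('/', '~')\n",
        "\n",
        "def agent_class_name(agent_id):\n",
        "    # Mirrors get_agent_class, returning the class name instead of the class.\n",
        "    return agent_id.split('~')[0]\n",
        "\n",
        "aid = format_agent_id('Vit', 'cifar100')  # 'Vit~cifar100'\n",
        "```\n",
        "\n",
        "Task names containing `/` (such as `root_model/checkpoint`) also map those to `~`, and the class name is still recoverable from the first segment."
      ]
    },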
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TgGHyptb_BwQ"
      },
      "outputs": [],
      "source": [
        "def incremental_mutation(value, values_list):\n",
        "  assert value in values_list, f\"{value} not in {values_list}\"\n",
        "  idx = values_list.index(value)\n",
        "  idx += 1 if np.random.uniform() \u003c 0.5 else -1\n",
        "  idx = max(0, min(len(values_list)-1, idx))\n",
        "  return values_list[idx]"
      ]
    },
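    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`incremental_mutation` performs a random local step: move one position up or down the sorted list of candidate values, clamping at the ends. A standalone sketch (the explicit `rng` argument is added here for reproducibility):\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "def incremental_mutation(value, values_list, rng):\n",
        "    # Step one position up or down the candidate list, clamped at the ends.\n",
        "    assert value in values_list, f'{value} not in {values_list}'\n",
        "    idx = values_list.index(value)\n",
        "    idx += 1 if rng.random() \u003c 0.5 else -1\n",
        "    idx = max(0, min(len(values_list) - 1, idx))\n",
        "    return values_list[idx]\n",
        "\n",
        "rng = random.Random(0)\n",
        "lrs = [0.001, 0.01, 0.1]\n",
        "mutated = incremental_mutation(0.01, lrs, rng)  # A neighbor: 0.001 or 0.1.\n",
        "edge = incremental_mutation(0.001, lrs, rng)    # Clamped: 0.001 or 0.01.\n",
        "```"
      ]
    },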
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2AwvZWRxnao5"
      },
      "outputs": [],
      "source": [
        "DATASET_HPARAMS_KEYS_PRERFIX = \"ds_\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jCNLTP1zGu_u"
      },
      "outputs": [],
      "source": [
        "class Agent():\n",
        "  @property\n",
        "  def class_name(self):\n",
        "    return self.__class__.__name__\n",
        "\n",
        "  @property\n",
        "  def id(self):\n",
        "    return self.config.agent_id\n",
        "\n",
        "  @staticmethod\n",
        "  def get_model_factory():\n",
        "    assert False, \"Not implemented\"\n",
        "\n",
        "  def run(self):\n",
        "    assert False, \"Not implemented\"\n",
        "\n",
        "  def complete_config(self, system_state_dir, task_name, num_cycles_max):\n",
        "    self.config.system_state_dir = system_state_dir\n",
        "    self.config.task_name = task_name\n",
        "    self.config.agent_id = format_agent_id(self.class_name, task_name)\n",
        "    self.config.agent_dir = os.path.join(system_state_dir, self.id)\n",
        "    self.config.num_cycles_max = num_cycles_max\n",
        "\n",
        "  def agent_classes_to_load(self):\n",
        "    # Defaults to loading only agents of the same class. Extend this list to\n",
        "    # allow access to the structures and parameters produced by other agent types.\n",
        "    return [self.class_name]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5j4LPnTA8EzH"
      },
      "outputs": [],
      "source": [
        "def run_cycles(agent):\n",
        "  config = agent.config\n",
        "  task_name = config.task_name\n",
        "  num_cycles = config.num_cycles_max\n",
        "  for _ in range(num_cycles):\n",
        "    agent.load_state()\n",
        "    if agent.cycle_id \u003e= num_cycles:\n",
        "      break\n",
        "    print(\"\\n\\n====\")\n",
        "    print(f\"CYCLE: [{agent.cycle_id+1}/{num_cycles}]\")\n",
        "    agent.pop.start_cycle()\n",
        "    agent_cycle(agent)\n",
        "    agent.pop.end_cycle()\n",
        "    agent.cycle_id += 1\n",
        "    agent.generation_id = 0\n",
        "    save_state(agent)\n",
        "    if agent.cycle_id \u003e= num_cycles:\n",
        "      break\n",
        "\n",
        "def run_root_model(agent):\n",
        "  agent.load_state()\n",
        "  save_state(agent)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oMUf-hLFF4Wc"
      },
      "outputs": [],
      "source": [
        "# Run a full path-sampling iteration for a task.\n",
        "def agent_cycle(agent):\n",
        "  pop = agent.pop\n",
        "  config = agent.config\n",
        "  task = Path.cached_tasks(task_name=config.task_name)\n",
        "  best_path = pop.get_best_path()\n",
        "  if TEST_IMMUTABILITY and best_path:\n",
        "    run_test_eval(best_path, test_immutability=True)\n",
        "  devices = jax.local_devices()\n",
        "  print(\"DEVICE COUNT:\", len(devices))\n",
        "  num_gen_batches = math.ceil(config.num_samples_per_cycle/len(devices))\n",
        "  for _ in range(num_gen_batches):\n",
        "    if agent.generation_id \u003e= num_gen_batches:\n",
        "      break\n",
        "    print(f\"----\\nGENERATION: [{agent.generation_id+1}/{num_gen_batches}]\")\n",
        "    ds_hparams = agent.sample_ds_hparams()\n",
        "    ds_hparams[\"num_classes\"] = task.num_classes\n",
        "    paths = []\n",
        "    for i in range(len(devices)):\n",
        "      print(f\"Sampling path {Path.counter}\")\n",
        "      paths.append(agent.sample_path(ds_hparams))\n",
        "      gc.collect()\n",
        "    ds_hparams = agent.finalize_ds_hparams(ds_hparams, paths)\n",
        "    ds_train = task.get_ds(\"train\", ds_hparams)\n",
        "    ds_validation = task.get_ds(\"validation\", ds_hparams)\n",
        "    train_loop(paths, ds_train, ds_validation, devices, config)\n",
        "    for path in paths:\n",
        "      path.metrics[\"generation_id\"] = agent.generation_id\n",
        "      if path.metrics[\"improved\"]:\n",
        "        assert path not in pop.paths[config.agent_id]\n",
        "        pop.paths[config.agent_id].append(path)\n",
        "    pop.prune_population()\n",
        "    # Track best path.\n",
        "    curr_best_path = pop.get_best_path()\n",
        "    if curr_best_path != best_path:\n",
        "      if best_path:\n",
        "        assert curr_best_path.score() \u003e= best_path.score()\n",
        "      best_path = curr_best_path\n",
        "      best_path.metrics[\"new_best\"] = True\n",
        "      agent.print_best_path_summary()\n",
        "    agent.generation_id += 1\n",
        "    df_leaderboard(pop_to_df(pop))\n",
        "    if agent.generation_id \u003c num_gen_batches:\n",
        "      save_state(agent)\n",
        "  assert best_path in pop.paths[config.agent_id], best_path\n",
        "  run_test_eval(best_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9I56DcbI79M2"
      },
      "source": [
        "## Vit Agent"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "njMI52JMnqq1"
      },
      "outputs": [],
      "source": [
        "def get_common_config_vit():\n",
        "  config = ConfigDict()\n",
        "  config.num_train_examples_between_validations_max = 300_000\n",
        "  config.num_validations_per_path_training = 4\n",
        "  config.num_validation_examples_max = 10_000\n",
        "  config.num_samples_per_cycle = 16\n",
        "  config.max_task_population_size = 5\n",
        "  # Force fine-tuning of the last layer norm, which is technically part of the head.\n",
        "  config.force_mutations = [\"clone:encoder_norm\"]\n",
        "  config.scorer_kwargs = dict(\n",
        "      scale_factor=0.99,\n",
        "      base_accounted_params=2_200_000_000,\n",
        "      base_flops=3_800_000_000_000,\n",
        "      )\n",
        "  config.hparams_defaults = {\n",
        "      \"_mu_\": 0.2,\n",
        "      \"opt_lr\": 0.02,\n",
        "      \"opt_lr_schedule\": \"cosine\",\n",
        "      \"opt_lr_warmup_ratio\": 0.02,\n",
        "      \"opt_momentum\": 0.8,\n",
        "      \"opt_nesterov\": True,\n",
        "      \"ds_area_range_min\": 1.0,\n",
        "      \"ds_aspect_ratio_range_min\": 1.0,\n",
        "      \"ds_flip_left_right\": False,\n",
        "      \"ds_brightness_delta\": 0.0,\n",
        "      \"ds_contrast_delta\": 0.0,\n",
        "      \"ds_saturation_delta\": 0.0,\n",
        "      \"ds_hue_delta\": 0.0,\n",
        "      \"ds_quality_delta\": 0.0,\n",
        "  }\n",
        "  config.hparams_mutation_ranges = {\n",
        "      \"_mu_\": [0.02, 0.04, 0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18, 0.20, 0.22, 0.24, 0.26, 0.28, 0.30],\n",
        "      \"opt_lr\": [0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0],\n",
        "      \"opt_lr_warmup_ratio\": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2, 0.3],\n",
        "      \"opt_momentum\": [0.5, 0.6, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.98, 0.99],\n",
        "      \"opt_nesterov\": [True, False],\n",
        "      \"ds_area_range_min\": [0.05, 0.5, 0.95, 1.0],\n",
        "      \"ds_aspect_ratio_range_min\": [0.5, 0.75, 1.0],\n",
        "      \"ds_flip_left_right\": [True, False],\n",
        "      \"ds_brightness_delta\": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],\n",
        "      \"ds_contrast_delta\": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],\n",
        "      \"ds_saturation_delta\": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],\n",
        "      \"ds_hue_delta\": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],\n",
        "      \"ds_quality_delta\": [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],\n",
        "  }\n",
        "  return config\n",
        "\n",
        "def get_config_vit_large():\n",
        "  config = get_common_config_vit()\n",
        "  config.batch_size = 16\n",
        "  # The query is used to get the model configs even if the checkpoint is not loaded.\n",
        "  config.vit_checkpoint_query = 'name==\"L/16\" and ds==\"i21k\" and aug==\"medium2\" and wd==0.03 and sd==0.1'\n",
        "  max_num_layers = get_max_num_layers(config.vit_checkpoint_query)\n",
        "  config.hparams_defaults[\"num_layers\"] = max_num_layers\n",
        "  config.hparams_mutation_ranges[\"num_layers\"] = list(\n",
        "      range(config.hparams_defaults[\"num_layers\"]+1\n",
        "            +1  # Allow to exceed root-model's layers by 1.\n",
        "            ))\n",
        "  config.hparams_defaults[\"ds_image_size\"] = 384\n",
        "  config.hparams_mutation_ranges[\"ds_image_size\"] = [224, 384]\n",
        "  return config\n",
        "\n",
        "def get_config_vit_ti3():\n",
        "  config = get_common_config_vit()\n",
        "  config.batch_size = 512\n",
        "  config.vit_checkpoint_query = 'name==\"Ti/16\" and ds==\"i21k\" and aug==\"light1\" and wd==0.1 and sd==0.0'\n",
        "  config.hparams_defaults[\"num_layers\"] = 3\n",
        "  config.hparams_mutation_ranges[\"num_layers\"] = list(range(config.hparams_defaults[\"num_layers\"]+1))\n",
        "  config.hparams_defaults[\"ds_image_size\"] = 32\n",
        "  config.hparams_mutation_ranges[\"ds_image_size\"] = [16*i for i in (range(1, 1+int(112/16)))]\n",
        "  return config\n",
        "\n",
        "def get_config_vit_base():\n",
        "  config = get_common_config_vit()\n",
        "  config.batch_size = 256\n",
        "  config.vit_checkpoint_query = 'name==\"B/16\" and ds==\"i21k\" and aug==\"medium1\" and wd==0.1 and sd==0'\n",
        "  max_num_layers = get_max_num_layers(config.vit_checkpoint_query)\n",
        "  config.hparams_defaults[\"num_layers\"] = max_num_layers\n",
        "  config.hparams_mutation_ranges[\"num_layers\"] = list(range(config.hparams_defaults[\"num_layers\"]+1))\n",
        "  config.hparams_defaults[\"ds_image_size\"] = 80\n",
        "  config.hparams_mutation_ranges[\"ds_image_size\"] = [16*i for i in (range(1, 1+int(112/16)))]\n",
        "  return config\n",
        "\n",
        "def config_validate(config):\n",
        "  for khp in config.hparams_defaults:\n",
        "    if khp in config.hparams_mutation_ranges:\n",
        "      assert config.hparams_defaults[khp] in config.hparams_mutation_ranges[khp], khp\n",
        "  for khp in config.hparams_mutation_ranges:\n",
        "    assert khp in config.hparams_defaults, khp"
      ]
    },
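    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`config_validate` enforces two invariants between the two hparam dicts: every default that is listed as mutable must be one of its own mutation candidates (so `incremental_mutation` can locate it), and every mutable hparam must have a default. A standalone sketch over plain dicts (the function name here is illustrative):\n",
        "\n",
        "```python\n",
        "def validate_hparam_spaces(hparams_defaults, hparams_mutation_ranges):\n",
        "    # Mutable defaults must sit inside their own mutation range...\n",
        "    for khp, default in hparams_defaults.items():\n",
        "        if khp in hparams_mutation_ranges:\n",
        "            assert default in hparams_mutation_ranges[khp], khp\n",
        "    # ...and every mutable hparam must have a default.\n",
        "    for khp in hparams_mutation_ranges:\n",
        "        assert khp in hparams_defaults, khp\n",
        "\n",
        "defaults = {'opt_lr': 0.02, 'opt_nesterov': True}\n",
        "ranges = {'opt_lr': [0.01, 0.02, 0.05], 'opt_nesterov': [True, False]}\n",
        "validate_hparam_spaces(defaults, ranges)  # Passes silently.\n",
        "```"
      ]
    },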
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_wzXhNikODR8"
      },
      "outputs": [],
      "source": [
        "class Vit(Agent):\n",
        "  \"\"\"ViT large\"\"\"\n",
        "  def __init__(self, system_state_dir, task_name, num_cycles_max):\n",
        "    self.config = self.get_config()\n",
        "    self.complete_config(system_state_dir, task_name, num_cycles_max)\n",
        "    self.cached_posembed_components = ObjectCache(get_vit_checkpoint_reshaped_posembed_component)\n",
        "\n",
        "  @staticmethod\n",
        "  def get_model_factory():\n",
        "    return VitModelFactory\n",
        "\n",
        "  def load_state(self):\n",
        "    task_name = self.config.task_name\n",
        "    self.pop = Population(self.config)\n",
        "    self.cycle_id = 0\n",
        "    self.generation_id = 0\n",
        "    # Root models.\n",
        "    if task_name.startswith(\"root_model/\"):\n",
        "      hparams = self.config.hparams_defaults.as_configdict()\n",
        "      if task_name == \"root_model/random_init\":\n",
        "        hparams[\"num_classes\"] = 0  # Removes head layer.\n",
        "        path_params = self.get_model_factory().get_init_comps(hparams, self.config)\n",
        "      else:\n",
        "        assert task_name == \"root_model/checkpoint\", task_name\n",
        "        path_params = get_vit_checkpoint_comps(\n",
        "            hparams[\"ds_image_size\"],\n",
        "            self.config.vit_checkpoint_query)\n",
        "      path = Path(\n",
        "          hparams,\n",
        "          params2comps(path_params, train_locks=[self.id], agent_id=self.id),\n",
        "          parent=None,\n",
        "          agent_id=self.id,\n",
        "          task_name=task_name)\n",
        "      self.pop.paths[self.id].append(path)\n",
        "      return\n",
        "\n",
        "    # Load latest agent state.\n",
        "    def validate_df(df):\n",
        "      assert len(df[\"agent_id\"].unique()) == 1, len(df[\"agent_id\"].unique())\n",
        "      assert df[\"agent_id\"].unique()[0] == self.id, df[\"agent_id\"].unique()[0]\n",
        "    agent_checkpoint = latest_checkpoint(\n",
        "        os.path.join(self.config.agent_dir, \"state_*_*/\"))\n",
        "    if agent_checkpoint:\n",
        "      matched = re.findall(r\"checkpoint_([0-9]+)_([0-9]+)$\", agent_checkpoint)\n",
        "      assert len(matched) == 1\n",
        "      self.cycle_id = int(matched[0][0])\n",
        "      self.generation_id = int(matched[0][1])\n",
        "      state_dir = os.path.dirname(agent_checkpoint)\n",
        "      self.pop.paths_df = df_read_from_csv(state_dir, \"paths\")\n",
        "      self.pop.comps_df = df_read_from_csv(state_dir, \"components\")\n",
        "      validate_df(self.pop.paths_df)\n",
        "      validate_df(self.pop.comps_df)\n",
        "      # Set globals.\n",
        "      Path.paths = []\n",
        "      Path.counter = 1 + int(self.pop.paths_df.id.max())\n",
        "      Component.counter = 1 + int(self.pop.comps_df.id.max())\n",
        "      # Get the id of the last component saved in a non-intermediate checkpoint.\n",
        "      non_intermediated_checkpoint = latest_checkpoint(\n",
        "          os.path.join(self.config.agent_dir, \"state_*_0/\"))\n",
        "      if non_intermediated_checkpoint:\n",
        "        ni_paths_df = df_read_from_csv(\n",
        "            os.path.dirname(non_intermediated_checkpoint), \"paths\")\n",
        "        validate_df(ni_paths_df)\n",
        "        Path.last_saved = int(ni_paths_df.id.max())\n",
        "        ni_comps_df = df_read_from_csv(\n",
        "            os.path.dirname(non_intermediated_checkpoint), \"components\")\n",
        "        validate_df(ni_comps_df)\n",
        "        Component.last_saved = int(ni_comps_df.id.max())\n",
        "      print(\"CONTINUING FROM STATE\", self.cycle_id, self.generation_id)\n",
        "\n",
        "    # Load all available paths.\n",
        "    all_agents_dirs = []\n",
        "    for agent_class_to_load in self.agent_classes_to_load():\n",
        "      all_agents_dirs.extend(\n",
        "          gfile.glob(os.path.join(self.config.system_state_dir,\n",
        "                                  agent_class_to_load+\"~*\")))\n",
        "    assert all_agents_dirs, f\"No state for agents: {self.agent_classes_to_load()}\"\n",
        "    state_dir = os.path.dirname(agent_checkpoint) if agent_checkpoint else None\n",
        "    load_paths(self.pop, state_dir, all_agents_dirs)\n",
        "\n",
        "    assert self.pop.paths, \"Population is empty, run an agent creating a \" \\\n",
        "        \"root model to initialize the population.\"\n",
        "    df_leaderboard(pop_to_df(self.pop))\n",
        "\n",
        "  def get_config(self):\n",
        "    return get_config_vit_large()\n",
        "\n",
        "  def run(self):\n",
        "    if self.config.task_name.startswith(\"root_model/\"):\n",
        "      run_root_model(self)\n",
        "      return\n",
        "    run_cycles(self)\n",
        "\n",
        "  def complete_config(self, system_state_dir, task_name, num_cycles_max):\n",
        "    super().complete_config(system_state_dir, task_name, num_cycles_max)\n",
        "    self.config = FrozenConfigDict(self.config)\n",
        "    config_validate(self.config)\n",
        "\n",
        "  def do_mutate(self, hparams, mutation_name):\n",
        "    \"\"\"Returns True if mutation is sampled to be applied.\"\"\"\n",
        "    if mutation_name in self.config.get(\"force_mutations\", []):\n",
        "      return True\n",
        "    mutation_prob_k = f\"_mu_|{mutation_name}\"\n",
        "    # The fallback is used for batch-shared sampling.\n",
        "    mu = hparams.get(\"_mu_\", self.config.hparams_defaults[\"_mu_\"])\n",
        "    mutation_prob = hparams.get(mutation_prob_k, mu)\n",
        "    if \"_mu_\" in self.config.hparams_mutation_ranges:\n",
        "      if mu \u003e np.random.uniform():\n",
        "        mutation_prob = incremental_mutation(\n",
        "            mutation_prob, self.config.hparams_mutation_ranges[\"_mu_\"])\n",
        "      hparams[mutation_prob_k] = mutation_prob\n",
        "    return mutation_prob \u003e np.random.uniform()\n",
        "\n",
        "  def parent_decay_selection(self):\n",
        "    for path in sorted(self.pop.paths[self.config.agent_id],\n",
        "                       key=lambda p: p.score(),\n",
        "                       reverse=True):\n",
        "      offsprings = path.metrics.get(\"offsprings\", 0)\n",
        "      assert not math.isnan(offsprings)\n",
        "      select_prob = 0.5 ** offsprings\n",
        "      print(f\" Candidate parent path {path.id},\",\n",
        "            f\"selection probability: 0.5^{offsprings} == {select_prob}\")\n",
        "      if np.random.uniform() \u003c select_prob:\n",
        "        path.metrics[\"offsprings\"] = path.metrics.get(\"offsprings\", 0) + 1\n",
        "        return path\n",
        "    return None\n",
        "\n",
        "  def sample_path(self, ds_hparams):\n",
        "    parent = self.parent_decay_selection()\n",
        "    if not parent:  # Random sample.\n",
        "      parent = random.choice([p for paths in self.pop.paths.values() for p in paths])\n",
        "      print(f\" Randomly selected parent {parent.agent_id}:{parent.id}\")\n",
        "    return self.mutate_parent(parent, ds_hparams)\n",
        "\n",
        "  def mutate_hparams(self, hparams):\n",
        "    for k in sorted(self.config.hparams_mutation_ranges):\n",
        "      if k in hparams and self.do_mutate(hparams, f\"hp:{k}\"):\n",
        "        hparams[k] = incremental_mutation(\n",
        "            hparams[k], self.config.hparams_mutation_ranges[k])\n",
        "    return hparams\n",
        "\n",
        "  def sample_ds_hparams(self):\n",
        "    \"\"\"Sample hparams that must be shared across each generation of paths.\"\"\"\n",
        "    ds_hparams = {}\n",
        "    # Initialize shared hparams with defaults.\n",
        "    for key in self.config.hparams_defaults:\n",
        "      if key.startswith(DATASET_HPARAMS_KEYS_PRERFIX):\n",
        "        ds_hparams[key] = self.config.hparams_defaults[key]\n",
        "    # Overwrite with values from best path if available.\n",
        "    best_path = self.pop.get_best_path()\n",
        "    if best_path:\n",
        "      ds_hparams.update(\n",
        "          {k : best_path.hparams[k] for k in ds_hparams if k in best_path.hparams})\n",
        "      ds_hparams.update(\n",
        "          {k : best_path.hparams[k] for k in best_path.hparams if k.startswith(\n",
        "              f\"_mu_|hp:{DATASET_HPARAMS_KEYS_PRERFIX}\")})\n",
        "      # Sample mutations.\n",
        "      ds_hparams = self.mutate_hparams(ds_hparams)\n",
        "    # Validate.\n",
        "    for k in ds_hparams:\n",
        "      assert (k.startswith(DATASET_HPARAMS_KEYS_PRERFIX) or\n",
        "              k.startswith(f\"_mu_|hp:{DATASET_HPARAMS_KEYS_PRERFIX}\"))\n",
        "    return ds_hparams\n",
        "\n",
        "  def finalize_ds_hparams(self, ds_hparams, paths):\n",
        "    # Validate shared params.\n",
        "    for k in ds_hparams:\n",
        "      if k.startswith(DATASET_HPARAMS_KEYS_PRERFIX):\n",
        "        for path in paths:\n",
        "          assert ds_hparams[k] == path.hparams[k]\n",
        "    return ds_hparams\n",
        "\n",
        "  def mutate_parent(self, parent, ds_hparams):\n",
        "    config = self.config\n",
        "    agent_id = config.agent_id\n",
        "    task_name = config.task_name\n",
        "    comps = []\n",
        "    new_hparams = copy.deepcopy(parent.hparams)\n",
        "    new_hparams = self.mutate_hparams(new_hparams)\n",
        "    # Overwrite dataset hparams with those sampled for the generation batch.\n",
        "    new_hparams.update(ds_hparams)\n",
        "\n",
        "    def get_component_ref(c, clone):\n",
        "      if c.is_trainable() or clone:\n",
        "        # Clone trainable component.\n",
        "        return c.clone(agent_id=agent_id)\n",
        "      # Refer to frozen component.\n",
        "      return c\n",
        "\n",
        "    init_params = self.get_model_factory().get_init_comps(new_hparams, config)\n",
        "    for new_comp_name in init_params:\n",
        "      comp = None\n",
        "      # Attempt to reuse a matching component from the closest ancestor.\n",
        "      ancestor = parent\n",
        "      while ancestor is not None:\n",
        "        comps_lookup = {c.name:c for c in ancestor.components}\n",
        "        if new_comp_name in comps_lookup:\n",
        "          # The head must be trainable: if no ancestor belongs to the same\n",
        "          # agent, fall back to a random init of the correct shape.\n",
        "          if new_comp_name == \"head\" and agent_id != ancestor.agent_id:\n",
        "            assert agent_id != ancestor.agent_id, f\"{agent_id} != {ancestor.agent_id}\"\n",
        "            ancestor = ancestor.parent\n",
        "            continue\n",
        "          # Check that the shapes match; otherwise handle the mismatch below.\n",
        "          if (jax.tree_util.tree_map(jnp.shape, init_params[new_comp_name]) !=\n",
        "              jax.tree_util.tree_map(jnp.shape, comps_lookup[new_comp_name].params)):\n",
        "            if new_comp_name == \"posembed_input\":\n",
        "              # A change of image size changed the shape of the position\n",
        "              # embeddings; this can happen when ds_image_size is tuned.\n",
        "              # Continue searching through ancestors for a matching size.\n",
        "              assert \"ds_image_size\" in config.hparams_mutation_ranges\n",
        "              assert new_hparams[\"ds_image_size\"] != ancestor.hparams[\"ds_image_size\"]\n",
        "              ancestor = ancestor.parent\n",
        "              continue\n",
        "\n",
        "            print(f\"WARNING: Shapes do not match for component: {new_comp_name}  {ancestor.agent_id}-\u003e{agent_id}\")\n",
        "            print(jax.tree_util.tree_map(jnp.shape, init_params[new_comp_name]))\n",
        "            print(jax.tree_util.tree_map(jnp.shape, comps_lookup[new_comp_name].params))\n",
        "            assert False  # Should not happen in current configuration.\n",
        "\n",
        "          ancestor_comp = comps_lookup[new_comp_name]\n",
        "          comp = get_component_ref(\n",
        "              ancestor_comp, clone=(\n",
        "                  ancestor_comp.is_trainable() or self.do_mutate(\n",
        "                      new_hparams, f\"clone:{new_comp_name}\")))\n",
        "          break\n",
        "        ancestor = ancestor.parent\n",
        "      # Get reshaped posembed_input from checkpoint.\n",
        "      if comp is None and new_comp_name == \"posembed_input\":\n",
        "        pe_comp = self.cached_posembed_components(\n",
        "            agent_id=agent_id,\n",
        "            query=config.vit_checkpoint_query,\n",
        "            **new_hparams)\n",
        "        # Clone to make the component trainable.\n",
        "        comp = get_component_ref(pe_comp, clone=True)\n",
        "      # Otherwise create one from random init params.\n",
        "      if comp is None:\n",
        "        # Combinations that can trigger random init in the current configurations.\n",
        "        assert (\n",
        "            new_comp_name == \"head\"\n",
        "            or (new_comp_name.startswith(\"encoderblock_\")\n",
        "                and config.hparams_defaults[\"num_layers\"] \u003c max(\n",
        "                config.hparams_mutation_ranges.get(\"num_layers\", [-1]))))\n",
        "        comp = params2comps(\n",
        "            init_params, train_locks=[],\n",
        "            agent_id=agent_id, name=new_comp_name)[0]\n",
        "      assert comp is not None\n",
        "      comps.append(comp)\n",
        "    return Path(new_hparams, comps, parent=parent, agent_id=agent_id, task_name=task_name)\n",
        "\n",
        "  def print_best_path_summary(self):\n",
        "    best_path = self.pop.get_best_path()\n",
        "    print(f\"Best id:{best_path.id}\",\n",
        "          f\"score:{best_path.score():.4f}\",\n",
        "          f\"quality:{best_path.metrics['quality']:.4f}\",\n",
        "          f\"gen:{best_path.metrics['generation_id']}\",\n",
        "          f\"\\n{best_path.hparams}\")\n",
        "\n",
        "  def get_paths_to_publish(self):\n",
        "    return list(self.pop.paths[self.config.agent_id])\n",
        "\n",
        "class VitT3(Vit):\n",
        "  \"\"\"ViT tiny 3 layers\"\"\"\n",
        "  def get_config(self):\n",
        "    return get_config_vit_ti3()\n",
        "\n",
        "class VitB(Vit):\n",
        "  \"\"\"ViT base\"\"\"\n",
        "  def get_config(self):\n",
        "    return get_config_vit_base()"
      ]
    },
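    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`parent_decay_selection` above visits paths best-first and accepts each with probability `0.5 ** offsprings`, so a parent that has already produced many offspring decays geometrically and the search drifts toward fresher paths. A standalone sketch over `(score, offsprings)` pairs (illustrative simplification of the method above):\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "def parent_decay_selection(paths, rng):\n",
        "    # paths: list of (score, offsprings) pairs, visited best-first;\n",
        "    # each is accepted with probability 0.5 ** offsprings.\n",
        "    for score, offsprings in sorted(paths, key=lambda p: p[0], reverse=True):\n",
        "        if rng.random() \u003c 0.5 ** offsprings:\n",
        "            return (score, offsprings)\n",
        "    return None  # Caller falls back to a uniformly random parent.\n",
        "\n",
        "rng = random.Random(0)\n",
        "paths = [(0.9, 5), (0.8, 0), (0.7, 2)]\n",
        "# Since 0.5 ** 0 == 1, the sweep always stops at the first\n",
        "# zero-offspring path, so the result is one of the first two here.\n",
        "picked = parent_decay_selection(paths, rng)\n",
        "```"
      ]
    },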
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-N57gaotPN55"
      },
      "source": [
        "## MultiVit Agent"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Cb3uj_kw9ED9"
      },
      "outputs": [],
      "source": [
        "def set_multivit_common_config(config):\n",
        "  config.force_mutations = []\n",
        "  config.scorer_kwargs = {}\n",
        "  for rm_k in [\"num_layers\", \"ds_image_size\"]:\n",
        "    del config.hparams_defaults[rm_k]\n",
        "    del config.hparams_mutation_ranges[rm_k]\n",
        "  config.hparams_defaults[\"router_init_main_path_weight\"] = 0.8\n",
        "  config.hparams_defaults[\"router_lr_mult\"] = 0.05\n",
        "  config.hparams_defaults[\"num_paths\"] = 2\n",
        "  config.hparams_mutation_ranges[\"router_lr_mult\"] = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1]\n",
        "  config.hparams_mutation_ranges[\"num_paths\"] = [2, 3]\n",
        "  return config"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3Id10GxBcyJi"
      },
      "outputs": [],
      "source": [
        "class MultiVit(Vit):\n",
        "  def get_config(self):\n",
        "    config = get_config_vit_large()\n",
        "    return set_multivit_common_config(config)\n",
        "\n",
        "  @staticmethod\n",
        "  def get_model_factory():\n",
        "    return MultiVitModelFactory\n",
        "\n",
        "  def run(self):\n",
        "    run_cycles(self)\n",
        "\n",
        "  def agent_classes_to_load(self):\n",
        "    return [self.single_path_agent_class()]\n",
        "\n",
        "  def single_path_agent_class(self):\n",
        "    return self.class_name.replace(\"Multi\", \"\")\n",
        "\n",
        "  def single_path_main_agent_id(self):\n",
        "    return format_agent_id(self.single_path_agent_class(), self.config.task_name)\n",
        "\n",
        "  def load_state(self):\n",
        "    if self.config.task_name.startswith(\"root_model/\"):\n",
        "      assert False, (\n",
        "          \"Root models need to be generated with the corresponding \" \\\n",
        "          f\"single path agent: '{self.single_path_agent_class()}'.\")\n",
        "    super().load_state()\n",
        "    assert self.single_path_main_agent_id() in self.pop.paths, (\n",
        "        \"Missing state for the corresponding single path main agent. \" \\\n",
        "        f\"Run agent '{self.single_path_agent_class()}' \" \\\n",
        "        f\"on the '{self.config.task_name}' task to generate it.\")\n",
        "\n",
        "  def sample_path(self, ds_hparams):\n",
        "    selected_paths = {}\n",
        "    selected_paths[\"path_0\"] = random.choice(\n",
        "        self.pop.paths[self.single_path_main_agent_id()])\n",
        "    parent = self.parent_decay_selection()\n",
        "    if parent is not None:\n",
        "      new_hparams = copy.deepcopy(parent.hparams)\n",
        "      print(\" Selected\")\n",
        "    else:\n",
        "      parent = selected_paths[\"path_0\"]\n",
        "      best_path = self.pop.get_best_path()\n",
        "      if best_path:\n",
        "        new_hparams = copy.deepcopy(best_path.hparams)\n",
        "      else:\n",
        "        new_hparams = self.config.hparams_defaults.to_dict()\n",
        "      print(\" New random\")\n",
        "    new_hparams[\"paths\"] = {}\n",
        "    new_hparams = self.mutate_hparams(new_hparams)\n",
        "    # Overwrite dataset hparams with those sampled for the generation batch.\n",
        "    new_hparams.update(ds_hparams)\n",
        "\n",
        "    for path_name in [f\"path_{i}\" for i in range(int(new_hparams[\"num_paths\"]))]:\n",
        "      if path_name in selected_paths:\n",
        "        continue\n",
        "      if path_name in parent.hparams.get(\"paths\", {}):\n",
        "        sub = parent.hparams[\"paths\"][path_name]\n",
        "        selected_paths[path_name] = self.pop.get_path_from_full_id(\n",
        "            sub[\"agent_id\"], sub[\"id\"])\n",
        "        print(\" Subpath from parent: \", path_name, selected_paths[path_name].full_id)\n",
        "        continue\n",
        "      selected_paths[path_name] = random.choice([\n",
        "          p for paths in self.pop.paths.values() for p in paths if (\n",
        "              p.agent_id not in [sp.agent_id for sp in selected_paths.values()]\n",
        "              and p.agent_id.startswith(f\"{self.single_path_agent_class()}~\")\n",
        "              and not p.agent_id.endswith(\"~1k\")  # Excludes VTAB-1k tasks.\n",
        "              # and (path_name != \"path_1\" or p.agent_id in ['Vit~i_naturalist2017'])  # Forces i_naturalist2017 selection.\n",
        "              )])\n",
        "      print(\" Subpath rand selected:\", path_name, selected_paths[path_name].full_id)\n",
        "\n",
        "    for (path_name, path) in selected_paths.items():\n",
        "      new_hparams[\"paths\"][path_name] = {\n",
        "          \"id\": path.id,\n",
        "          \"agent_id\": path.agent_id,\n",
        "          \"hparams\": copy.deepcopy(path.hparams)}\n",
        "    print(\" Sampled subpaths:\",\n",
        "          {k: v[\"agent_id\"] for k, v in new_hparams[\"paths\"].items()})\n",
        "    # Set headless model config for models of different tasks.\n",
        "    for k in new_hparams[\"paths\"]:\n",
        "      if selected_paths[k].task_name != self.config.task_name:\n",
        "        new_hparams[\"paths\"][k][\"hparams\"][\"num_classes\"] = 0\n",
        "    # Collect image sizes needed.\n",
        "    image_sizes = set()\n",
        "    for k in new_hparams[\"paths\"]:\n",
        "      image_sizes.add(int(new_hparams[\"paths\"][k][\"hparams\"][\"ds_image_size\"]))\n",
        "    new_hparams[\"ds_image_size\"] = list(image_sizes)\n",
        "    # Collect components.\n",
        "    init_params = self.get_model_factory().get_init_comps(new_hparams, self.config)\n",
        "    comps = []\n",
        "    for new_comp_name in init_params:\n",
        "      if new_comp_name in new_hparams[\"paths\"]:\n",
        "        comps.append(ComponentPath(name=new_comp_name,\n",
        "                                   path=selected_paths[new_comp_name]))\n",
        "      else:\n",
        "        comps_lookup = {c.name:c for c in parent.components}\n",
        "        if new_comp_name in comps_lookup and (\n",
        "          jax.tree_util.tree_map(jnp.shape, init_params[new_comp_name]) ==\n",
        "          jax.tree_util.tree_map(jnp.shape, comps_lookup[new_comp_name].params)):\n",
        "          print(\" COMP Reusing\", new_comp_name)\n",
        "          comp = comps_lookup[new_comp_name].clone(agent_id=self.config.agent_id)\n",
        "        else:\n",
        "          print(\" COMP Init\", new_comp_name)\n",
        "          comp = Component(name=new_comp_name,\n",
        "                           agent_id=self.config.agent_id,\n",
        "                           params=init_params[new_comp_name],\n",
        "                           train_locks=[])\n",
        "        comps.append(comp)\n",
        "    return Path(new_hparams, comps, parent=parent,\n",
        "                agent_id=self.config.agent_id, task_name=self.config.task_name)\n",
        "\n",
        "  def finalize_ds_hparams(self, ds_hparams, paths):\n",
        "    # Validate shared params.\n",
        "    for k in ds_hparams:\n",
        "      if k.startswith(DATASET_HPARAMS_KEYS_PRERFIX):\n",
        "        for path in paths:\n",
        "          assert ds_hparams[k] == path.hparams[k], (k, ds_hparams[k], path.hparams[k])\n",
        "    image_sizes = set()\n",
        "    for path in paths:\n",
        "      image_sizes.update(path.hparams[\"ds_image_size\"])\n",
        "    ds_hparams[\"ds_image_size\"] = list(image_sizes)\n",
        "    print(\"Image sizes:\", ds_hparams[\"ds_image_size\"])\n",
        "    return ds_hparams\n",
        "\n",
        "  def print_best_path_summary(self):\n",
        "    super().print_best_path_summary()\n",
        "    best_path = self.pop.get_best_path()\n",
        "    print(\"Paths used by best model:\",\n",
        "          {k: best_path.hparams[\"paths\"][k][\"agent_id\"] for k in best_path.hparams[\"paths\"]})\n",
        "\n",
        "  def get_paths_to_publish(self):\n",
        "    paths_to_publish = [p for p in self.pop.paths[self.config.agent_id]]\n",
        "    for path in list(paths_to_publish):\n",
        "      for k in path.hparams[\"paths\"]:\n",
        "        sub = path.hparams[\"paths\"][k]\n",
        "        paths_to_publish.append(\n",
        "            self.pop.get_path_from_full_id(sub[\"agent_id\"], sub[\"id\"]))\n",
        "    return paths_to_publish\n",
        "\n",
        "class MultiVitT3(MultiVit):\n",
        "  def get_config(self):\n",
        "    config = get_config_vit_ti3()\n",
        "    return set_multivit_common_config(config)\n",
        "\n",
        "class MultiVitB(MultiVit):\n",
        "  def get_config(self):\n",
        "    config = get_config_vit_base()\n",
        "    return set_multivit_common_config(config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gfGafBjhDpV7"
      },
      "source": [
        "# Tasks"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mBAsZx4IWWaS"
      },
      "outputs": [],
      "source": [
        "TFDS_BUILDERS_CACHE = {}\n",
        "def get_tfds_builder(tfds_name):\n",
        "  global TFDS_BUILDERS_CACHE\n",
        "  if tfds_name not in TFDS_BUILDERS_CACHE:\n",
        "    TFDS_BUILDERS_CACHE[tfds_name] = tfds.builder(tfds_name)\n",
        "    TFDS_BUILDERS_CACHE[tfds_name].download_and_prepare()\n",
        "  return TFDS_BUILDERS_CACHE[tfds_name]"
      ]
    },
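    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The builder cache above follows a simple build-once-per-process pattern. As an illustrative, self-contained sketch of that pattern (the `cached_build` and `_expensive` helpers below are hypothetical and not used elsewhere in this notebook):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "_BUILD_CACHE = {}\n",
        "\n",
        "def cached_build(name, build_fn):\n",
        "  \"\"\"Construct an expensive object at most once per process, keyed by name.\"\"\"\n",
        "  if name not in _BUILD_CACHE:\n",
        "    _BUILD_CACHE[name] = build_fn(name)\n",
        "  return _BUILD_CACHE[name]\n",
        "\n",
        "_calls = []\n",
        "def _expensive(name):\n",
        "  _calls.append(name)  # Track how many times construction actually runs.\n",
        "  return name.upper()\n",
        "\n",
        "# The second lookup hits the cache, so construction runs only once.\n",
        "print(cached_build(\"mnist\", _expensive), cached_build(\"mnist\", _expensive), _calls)"
      ]
    },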
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kqnhOnk5jbnm"
      },
      "outputs": [],
      "source": [
        "def get_default_splits(tfds_name):\n",
        "  info = get_tfds_builder(tfds_name).info\n",
        "  splits = list(info.splits.keys())\n",
        "  assert \"train\" in splits, splits\n",
        "  splits.remove(\"train\")\n",
        "  used_percent = 0\n",
        "  slice_percent = 5\n",
        "  pp = {}\n",
        "  for k in [\"test\", \"validation\"]:\n",
        "    if k in splits:\n",
        "      pp[k] = k\n",
        "      splits.remove(k)\n",
        "    else:\n",
        "      pp[k] = f\"train[{used_percent}%:{used_percent+slice_percent}%]\"\n",
        "      used_percent += slice_percent\n",
        "  pp[\"train\"] = f\"train[{used_percent}%:]\"\n",
        "  return pp\n",
        "\n",
        "def get_dataset_and_splits(tfds_name: str):\n",
        "  vtab_class = None\n",
        "  if tfds_name in [\"imagenet_v2\", \"cifar10_1\"]:\n",
        "    assert False, f\"{tfds_name} used as validation set for other tasks.\"\n",
        "\n",
        "  if tfds_name == \"imagenet2012\":\n",
        "    dataset = {\n",
        "        \"train\":\"imagenet2012\", \"validation\":\"imagenet_v2\", \"test\":\"imagenet2012\"}\n",
        "    splits = {\n",
        "        \"train\":\"train\", \"validation\":\"test\", \"test\":\"validation\"}\n",
        "  elif tfds_name == \"cifar100\":\n",
        "    dataset = tfds_name\n",
        "    splits = {\n",
        "        \"train\":\"train[:98%]\", \"validation\":\"train[98%:]\", \"test\":\"test\"}\n",
        "  elif tfds_name == \"cifar10\":\n",
        "    dataset = {\n",
        "        \"train\":\"cifar10\", \"validation\":\"cifar10_1\", \"test\":\"cifar10\"}\n",
        "    splits = {\n",
        "        \"train\":\"train\", \"validation\":\"test\", \"test\":\"test\"}\n",
        "  elif (tfds_name.startswith(\"visual_domain_decathlon/\") or\n",
        "        tfds_name in [\"i_naturalist2017\", \"i_naturalist2018\", \"places365_small\"]):\n",
        "    dataset = tfds_name\n",
        "    # Test has no labels, split validation in half.\n",
        "    splits =  {\n",
        "        \"train\":\"train\", \"validation\":\"validation[:50%]\", \"test\":\"validation[50%:]\"}\n",
        "  elif tfds_name.startswith(\"cmaterdb/\"):\n",
        "    dataset = tfds_name\n",
        "    # Increase size of validation set due to small dataset size.\n",
        "    splits =  {\n",
        "        \"train\":\"train[20%:]\", \"validation\":\"train[:20%]\", \"test\":\"test\"}\n",
        "  elif tfds_name == \"omniglot\":\n",
        "    # Test has no labels and validation is missing; use the additional splits.\n",
        "    dataset = tfds_name\n",
        "    splits = {\"train\":\"train\", \"validation\":\"small1\", \"test\":\"small2\"}\n",
        "  elif tfds_name.startswith(\"controlled_noisy_web_labels/\"):\n",
        "    dataset = tfds_name\n",
        "    splits =  {\n",
        "        \"train\":\"train_00\",\n",
        "        \"validation\":\"validation[:50%]\",\n",
        "        \"test\":\"validation[50%:]\"}\n",
        "  elif tfds_name.startswith(\"cycle_gan/\"):\n",
        "    dataset = tfds_name\n",
        "    splits =  {\n",
        "        \"train\":\"trainA[10%:]+trainB[10%:]\",\n",
        "        \"validation\":\"trainA[:10%]+trainB[:10%]\",\n",
        "        \"test\":\"testA+testB\"}\n",
        "  elif tfds_name in [\"imagenet_a\", \"imagenet_r\", \"imagenet_sketch\",\n",
        "                     \"siscore/rotation\", \"siscore/size\", \"siscore/location\",]:\n",
        "    # Only test split.\n",
        "    dataset = tfds_name\n",
        "    splits =  {\n",
        "        \"train\":\"test[10%:]\",\n",
        "        \"validation\":\"test[5%:10%]\",\n",
        "        \"test\":\"test[:5%]\"}\n",
        "  elif tfds_name in [\"pet_finder\"]:\n",
        "    # Explicitly use only the train split (e.g. test has no labels).\n",
        "    dataset = tfds_name\n",
        "    splits =  {\n",
        "        \"train\":\"train[10%:]\",\n",
        "        \"validation\":\"train[5%:10%]\",\n",
        "        \"test\":\"train[:5%]\"}\n",
        "  elif tfds_name == \"quickdraw_bitmap\":\n",
        "    dataset = tfds_name\n",
        "    # Cap size of test and validation set.\n",
        "    splits =  {\n",
        "        \"train\":\"train[20000:]\", \"validation\":\"train[10000:20000]\", \"test\":\"train[:10000]\"}\n",
        "  elif tfds_name == \"stanford_online_products\":\n",
        "    dataset = tfds_name\n",
        "    # Use the first 10k test samples as validation since test has 60k.\n",
        "    splits =  {\n",
        "        \"train\":\"train\", \"validation\":\"test[:10000]\", \"test\":\"test[10000:]\"}\n",
        "  elif tfds_name in VTAB_TASKS or (\n",
        "      tfds_name.endswith(\"/1k\") and tfds_name.replace(\"/1k\", \"\") in VTAB_TASKS):\n",
        "    is_vtab_1k = tfds_name.endswith(\"/1k\")\n",
        "    tfds_name = tfds_name.replace(\"/1k\", \"\")\n",
        "    registry_name = {\n",
        "        \"diabetic_retinopathy_detection/btgraham-300\": \"diabetic_retinopathy\",\n",
        "        \"svhn_cropped\": \"svhn\",\n",
        "        \"cifar100\": \"cifar\",\n",
        "        \"cifar10\": \"cifar\",\n",
        "    }.get(tfds_name, tfds_name.split(\"/\")[0])\n",
        "    args = {\n",
        "        \"clevr/count_all\": (\"count_all\",),\n",
        "        \"clevr/count_cylinders\": (\"count_cylinders\",),\n",
        "        \"clevr/closest_object_distance\": (\"closest_object_distance\",),\n",
        "        \"dsprites/label_x_position\": (\"label_x_position\",),\n",
        "        \"dsprites/label_orientation\": (\"label_orientation\",),\n",
        "        \"kitti/closest_object_distance\": (\"closest_object_distance\",),\n",
        "        \"kitti/count_vehicles\": (\"count_vehicles\",),\n",
        "        \"kitti/closest_vehicle_distance\": (\"closest_vehicle_distance\",),\n",
        "        \"smallnorb/label_category\": (\"label_category\",),\n",
        "        \"smallnorb/label_lighting\": (\"label_lighting\",),\n",
        "        \"smallnorb/label_azimuth\": (\"label_azimuth\",),\n",
        "        \"smallnorb/label_elevation\": (\"label_elevation\",),\n",
        "        \"cifar100\": (100,),\n",
        "        \"cifar10\": (10,),\n",
        "    }.get(tfds_name, ())\n",
        "    vtab_class = task_adapt_registry.Registry.lookup(\n",
        "        f\"data.{registry_name}\")(*args)\n",
        "    vtab_splits = vtab_class._tfds_splits\n",
        "    dataset = {\n",
        "        \"caltech101\": \"caltech101:3.*.*\",\n",
        "        \"dtd\": \"dtd:3.*.*\",\n",
        "        \"oxford_flowers102\": \"oxford_flowers102:2.*.*\",\n",
        "        \"oxford_iiit_pet\": \"oxford_iiit_pet:3.*.*\",\n",
        "        \"sun397\": \"sun397/tfds:4.*.*\",\n",
        "        \"svhn\": \"svhn_cropped:3.*.*\",\n",
        "        \"patch_camelyon\": \"patch_camelyon:2.*.*\",\n",
        "        \"eurosat\": \"eurosat/rgb:2.*.*\",\n",
        "        \"resisc45\": \"resisc45:3.*.*\",\n",
        "        \"diabetic_retinopathy\": \"diabetic_retinopathy_detection/btgraham-300:3.*.*\",\n",
        "        \"clevr\": \"clevr:3.*.*\",\n",
        "        \"dmlab\": \"dmlab:2.0.1\",\n",
        "        \"dsprites\": \"dsprites:2.*.*\",\n",
        "        \"kitti\": \"kitti:3.2.0\",\n",
        "        \"smallnorb\": \"smallnorb:2.*.*\",\n",
        "        \"cifar\" : \"cifar100:3.*.*\" if tfds_name == \"cifar100\" else \"cifar10:3.*.*\",\n",
        "    }[registry_name]\n",
        "    if is_vtab_1k:\n",
        "      splits =  {\n",
        "          \"train\": str(vtab_splits[\"train800\"]),\n",
        "          \"validation\": str(vtab_splits[\"val200\"]),\n",
        "          \"test\": str(vtab_splits[\"test\"]),\n",
        "          }\n",
        "    else:\n",
        "      splits =  {\n",
        "          \"train\": str(vtab_splits[\"train\"]),\n",
        "          \"validation\": str(vtab_splits[\"val\"]),\n",
        "          \"test\": str(vtab_splits[\"test\"]),\n",
        "          }\n",
        "  else:\n",
        "    dataset = tfds_name\n",
        "    splits = get_default_splits(tfds_name)\n",
        "  return dataset, splits, vtab_class\n",
        "\n",
        "\n",
        "class Task():\n",
        "  def __init__(self, name, config):\n",
        "    self.config = config\n",
        "\n",
        "    self.dataset, self.splits, self.vtab_class = get_dataset_and_splits(name)\n",
        "    self.name = name\n",
        "    if self.vtab_class:\n",
        "      self.num_classes = self.vtab_class.get_num_classes()\n",
        "    else:\n",
        "      self.num_classes = self.get_builder(\n",
        "          \"train\").info.features[self.get_label_key()].num_classes\n",
        "    num_train_examples = self.get_builder(\n",
        "        \"train\").info.splits[self.splits[\"train\"]].num_examples\n",
        "    self.train_batch_size = config.batch_size\n",
        "    self.num_train_batches_between_validations = math.ceil(\n",
        "        min(num_train_examples,\n",
        "            config.num_train_examples_between_validations_max)\n",
        "        / self.train_batch_size)\n",
        "\n",
        "    num_validation_examples_tot = self.get_builder(\n",
        "        \"validation\").info.splits[self.splits[\"validation\"]].num_examples\n",
        "    if config.num_validation_examples_max \u003c= num_validation_examples_tot:\n",
        "      self.validation_batch_size = config.batch_size\n",
        "      self.num_validation_batches = math.floor(\n",
        "          config.num_validation_examples_max / self.validation_batch_size)\n",
        "    else:\n",
        "      # Adjust batch_size and num_batches to cover the smaller validation sets.\n",
        "      self.num_validation_batches = math.ceil(\n",
        "          num_validation_examples_tot / config.batch_size)\n",
        "      self.validation_batch_size = math.floor(\n",
        "          num_validation_examples_tot / self.num_validation_batches)\n",
        "      assert num_validation_examples_tot \u003e= (\n",
        "          self.num_validation_batches*self.validation_batch_size)\n",
        "    self.num_validation_examples = (\n",
        "        self.num_validation_batches * self.validation_batch_size)\n",
        "\n",
        "    print(f\"Task: {self.name}\")\n",
        "    print(f\"  Train batches between validations: {self.num_train_batches_between_validations}\")\n",
        "    print(f\"  Validation batches: {self.num_validation_batches}\")\n",
        "    print(f\"  Validation batch size: {self.validation_batch_size}\")\n",
        "    print(f\"  Dataset {{\\n{self.dataset}}}\")\n",
        "    print(f\"  Splits {{\\n{self.splits}}}\")\n",
        "\n",
        "  def get_label_key(self):\n",
        "    return {\n",
        "        \"stanford_online_products\": \"super_class_id\",\n",
        "        }.get(self.name, \"label\")\n",
        "\n",
        "  def get_builder(self, mode):\n",
        "    if isinstance(self.dataset, str):\n",
        "      return get_tfds_builder(self.dataset)\n",
        "    return get_tfds_builder(self.dataset[mode])\n",
        "\n",
        "  def get_ds(self, mode, hparams):\n",
        "    data = self.get_builder(mode).as_dataset(\n",
        "        split=self.splits[mode],\n",
        "        shuffle_files=mode==\"train\")\n",
        "\n",
        "    def _pp(data):\n",
        "      im = data[\"image\"]\n",
        "      tf.debugging.assert_type(im, tf.uint8)\n",
        "\n",
        "      if mode == \"train\":\n",
        "        if hparams.get(\"ds_quality_delta\", 0.0) \u003e 0.0:\n",
        "          im = tf.image.random_jpeg_quality(\n",
        "              im,\n",
        "              min_jpeg_quality=int(100 * (1 - hparams[\"ds_quality_delta\"])),\n",
        "              max_jpeg_quality=100)\n",
        "\n",
        "      # Must have 3 channels.\n",
        "      if im.shape[-1] == 1:\n",
        "        im = tf.squeeze(tf.stack([im] * 3, -1), axis=-2)\n",
        "      assert im.shape[-1] == 3\n",
        "      im = tf.cast(im, tf.float32)\n",
        "      if mode == \"train\":\n",
        "        if hparams.get(\"ds_area_range_min\", 1.0) \u003c 1.0:\n",
        "          channels = im.shape[-1]\n",
        "          begin, size, _ = tf.image.sample_distorted_bounding_box(\n",
        "              tf.shape(im),\n",
        "              tf.zeros([0, 0, 4], tf.float32),\n",
        "              aspect_ratio_range=[hparams[\"ds_aspect_ratio_range_min\"],\n",
        "                                  1.0/hparams[\"ds_aspect_ratio_range_min\"]],\n",
        "              area_range=[hparams[\"ds_area_range_min\"], 1.0],\n",
        "              # No overlap with a bounding box is required; the box\n",
        "              # defaults to the whole image in this case anyway.\n",
        "              min_object_covered=0,\n",
        "              use_image_if_no_bounding_boxes=True)\n",
        "          im = tf.slice(im, begin, size)\n",
        "          # Restore the static channel dimension lost by tf.slice.\n",
        "          im.set_shape([None, None, channels])\n",
        "        if hparams.get(\"ds_flip_left_right\", False):\n",
        "          if tf.random.uniform(shape=[]) \u003e 0.5:\n",
        "            im = tf.image.flip_left_right(im)\n",
        "        if hparams.get(\"ds_brightness_delta\", 0.0) \u003e 0.0:\n",
        "          im = tf.image.random_brightness(\n",
        "              im, max_delta=hparams[\"ds_brightness_delta\"])\n",
        "        if hparams.get(\"ds_contrast_delta\", 0.0) \u003e 0.0:\n",
        "          im = tf.image.random_contrast(\n",
        "              im, lower=1-hparams[\"ds_contrast_delta\"],\n",
        "              upper=1+hparams[\"ds_contrast_delta\"])\n",
        "        if hparams.get(\"ds_saturation_delta\", 0.0) \u003e 0.0:\n",
        "          im = tf.image.random_saturation(\n",
        "              im, lower=1-hparams[\"ds_saturation_delta\"],\n",
        "              upper=1 + hparams[\"ds_saturation_delta\"])\n",
        "        if hparams.get(\"ds_hue_delta\", 0.0) \u003e 0.0:\n",
        "          im = tf.image.random_hue(im, max_delta=hparams[\"ds_hue_delta\"])\n",
        "\n",
        "      def get_formatted_image(image, image_size):\n",
        "        image = tf.image.resize(image, [image_size, image_size])\n",
        "        # Scale values to the range [-1, 1].\n",
        "        image = image / 127.5 - 1\n",
        "        image = tf.clip_by_value(image, -1, 1)\n",
        "        return image\n",
        "\n",
        "      if isinstance(hparams[\"ds_image_size\"], list):\n",
        "        out_im = {}\n",
        "        for im_size in hparams[\"ds_image_size\"]:\n",
        "          out_im[str(im_size)] = get_formatted_image(im, int(im_size))\n",
        "      else:\n",
        "        out_im = get_formatted_image(im, int(hparams[\"ds_image_size\"]))\n",
        "      return {\"image\": out_im,\n",
        "              \"label\": data[self.get_label_key()]}\n",
        "\n",
        "    if mode == \"validation\":\n",
        "      data = data.take(self.num_validation_examples).cache()\n",
        "    if mode != \"test\":\n",
        "      data = data.repeat()\n",
        "    if self.vtab_class and self.vtab_class._base_preprocess_fn:\n",
        "      data = data.map(self.vtab_class._base_preprocess_fn, tf.data.AUTOTUNE)\n",
        "    data = data.map(_pp, tf.data.AUTOTUNE)\n",
        "    if mode == \"train\":\n",
        "      batch_size = self.train_batch_size\n",
        "    else:\n",
        "      batch_size = self.validation_batch_size\n",
        "    data = data.batch(batch_size)\n",
        "    if mode == \"train\":\n",
        "      data = data.shuffle(10)\n",
        "    return tfds.as_numpy(data.prefetch(tf.data.AUTOTUNE))\n",
        "\n",
        "def get_task_factory_fn(config):\n",
        "  def get_task(task_name: str):\n",
        "    return Task(name=task_name, config=config)\n",
        "  return get_task"
      ]
    },
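    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The split-carving logic in `get_default_splits` above can be illustrated in isolation. The following sketch (a hypothetical helper, `carve_default_splits`, not used by the rest of this notebook) reproduces how each missing `test`/`validation` split is carved as a 5% slice from the head of `train`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def carve_default_splits(available_splits):\n",
        "  \"\"\"Carve validation/test slices from 'train' when a dataset lacks them.\"\"\"\n",
        "  splits = list(available_splits)\n",
        "  assert \"train\" in splits, splits\n",
        "  splits.remove(\"train\")\n",
        "  used_percent = 0\n",
        "  slice_percent = 5\n",
        "  pp = {}\n",
        "  for k in [\"test\", \"validation\"]:\n",
        "    if k in splits:\n",
        "      pp[k] = k  # The dataset provides this split; use it directly.\n",
        "      splits.remove(k)\n",
        "    else:\n",
        "      # Carve the next 5% slice from the head of the train split.\n",
        "      pp[k] = f\"train[{used_percent}%:{used_percent+slice_percent}%]\"\n",
        "      used_percent += slice_percent\n",
        "  # The remainder of train is left for training.\n",
        "  pp[\"train\"] = f\"train[{used_percent}%:]\"\n",
        "  return pp\n",
        "\n",
        "# A dataset with only a train split gets both held-out sets carved from it.\n",
        "print(carve_default_splits([\"train\"]))"
      ]
    },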
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GdzUzvhrWLer"
      },
      "source": [
        "# Components"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rzxoZ4rdQZcA"
      },
      "outputs": [],
      "source": [
        "def params2comps(params, train_locks, agent_id, name=None):\n",
        "  \"\"\"Convert a frozen dict of params to a list of components.\"\"\"\n",
        "  components = []\n",
        "  for k in params:\n",
        "    if name is None or name == k:\n",
        "      c = Component(\n",
        "          name=k, agent_id=agent_id,\n",
        "          params=params[k], train_locks=train_locks)\n",
        "      components.append(c)\n",
        "  return components"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DNxjDX13_dm_"
      },
      "outputs": [],
      "source": [
        "def fingerprint_params(params):\n",
        "  return np.sum(np.array(jax.tree_util.tree_leaves(\n",
        "      jax.tree_util.tree_map(jnp.sum, params))))\n",
        "\n",
        "class Component():\n",
        "  counter = 0\n",
        "  # Components of retained paths with id \u003c= last_saved are saved in checkpoint.\n",
        "  last_saved = -1\n",
        "\n",
        "  @staticmethod\n",
        "  def reset_globals():\n",
        "    Component.counter = 0\n",
        "    Component.last_saved = -1\n",
        "\n",
        "  def __init__(\n",
        "      self, name: str, agent_id: str, params, train_locks):\n",
        "    self.name = name\n",
        "    self.agent_id = agent_id\n",
        "    self.params = jax.device_get(params)\n",
        "    self.num_params = None\n",
        "    self.train_locks = set(train_locks)\n",
        "    self.id = Component.counter\n",
        "    Component.counter += 1\n",
        "\n",
        "  def get_num_params(self):\n",
        "    if self.num_params is None:\n",
        "      self.num_params = get_num_params(self.params)\n",
        "    return self.num_params\n",
        "\n",
        "  def fingerprint(self):\n",
        "    return fingerprint_params(self.params)\n",
        "\n",
        "  def is_trainable(self):\n",
        "    return len(self.train_locks) == 0\n",
        "\n",
        "  def clone(self, agent_id):\n",
        "    return Component(name=self.name,\n",
        "                     agent_id=agent_id,\n",
        "                     params=copy.deepcopy(jax.device_get(self.params)),\n",
        "                     train_locks=set())"
      ]
    },
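    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Components record `train_locks`, and the parameter-accounting scheme (see `Path.get_num_accounted_params` further below) splits a component's parameter count evenly among the agents that lock it plus the owning agent. A minimal standalone sketch of that rule (the `accounted_params` helper is hypothetical and not used elsewhere in this notebook):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def accounted_params(components, agent_id):\n",
        "  \"\"\"Share each component's parameter count among its locking agents.\n",
        "\n",
        "  A component train-locked by N other agents, plus the owning agent,\n",
        "  contributes num_params / (N + 1) to the owner's accounted total.\n",
        "  \"\"\"\n",
        "  total = 0.0\n",
        "  for num_params, train_locks in components:\n",
        "    owners = set(train_locks)\n",
        "    owners.add(agent_id)  # The owning agent always counts as an owner.\n",
        "    total += num_params / len(owners)\n",
        "  return total\n",
        "\n",
        "# A 100-param trainable component plus a 300-param component shared with\n",
        "# one other agent: 100/1 + 300/2 = 250 accounted params.\n",
        "print(accounted_params([(100, set()), (300, {\"Vit~other_task\"})], \"Vit~task\"))"
      ]
    },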
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SdMf8WcYBbjt"
      },
      "outputs": [],
      "source": [
        "class ComponentPath():\n",
        "  \"\"\"Wraps a Path to be used as a Component.\"\"\"\n",
        "  def __init__(self, name, path):\n",
        "    self.name = name\n",
        "    self.path = path\n",
        "    self.train_locks = set([\"FROZEN\"])\n",
        "\n",
        "  def is_trainable(self):\n",
        "    return False\n",
        "\n",
        "  @property\n",
        "  def params(self):\n",
        "    return flax.core.freeze(self.path.get_all_params())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pCie1hQLUyP7"
      },
      "source": [
        " # Paths \u0026 Population"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8PU5ffvd_gC9"
      },
      "outputs": [],
      "source": [
        "class Path():\n",
        "  @staticmethod\n",
        "  def reset_globals(config):\n",
        "    Path.config = config\n",
        "    Path.counter = 0\n",
        "    Path.last_saved = -1\n",
        "    Path.paths = []\n",
        "    Path.scorer = globals()[config.get(\"scorer_class\", \"ScorerDecay\")](\n",
        "        **config.get(\"scorer_kwargs\", {}))\n",
        "    # Cache outputs of function calls with the same args.\n",
        "    Path.cached_tasks = ObjectCache(get_task_factory_fn(config))\n",
        "    Path.cached_optimizers = ObjectCache(get_optimizer)\n",
        "\n",
        "  def __init__(self, hparams, components, parent, agent_id, task_name):\n",
        "    self.components = components\n",
        "    self.id = Path.counter\n",
        "    Path.counter += 1\n",
        "    self.agent_id = agent_id\n",
        "    self.task_name = task_name\n",
        "    self.parent = parent\n",
        "    self.hparams = hparams\n",
        "    self._model = None\n",
        "    self.metrics = {\n",
        "        \"generation\": 0 if parent is None else parent.metrics[\"generation\"] + 1,\n",
        "    }\n",
        "    Path.paths.append(self)\n",
        "\n",
        "  @property\n",
        "  def task(self):\n",
        "    return Path.cached_tasks(task_name=self.task_name)\n",
        "\n",
        "  @property\n",
        "  def model_factory(self):\n",
        "    return get_agent_class(self.agent_id).get_model_factory()\n",
        "\n",
        "  @property\n",
        "  def model(self):\n",
        "    if self._model is None:\n",
        "      self._model = self.model_factory.get_model(self.hparams, self.config)\n",
        "    return self._model\n",
        "\n",
        "  @property\n",
        "  def full_id(self):\n",
        "    return f\"{self.agent_id}:{self.id}\"\n",
        "\n",
        "  def comps_only(self):  # Exclude wrapped paths.\n",
        "    return [c for c in self.components if c.__class__ is Component]\n",
        "\n",
        "  def score(self):\n",
        "    return Path.scorer.score(self)\n",
        "\n",
        "  def get_all_params(self):\n",
        "    params = {}\n",
        "    for c in self.components:\n",
        "      assert c.name not in params, c.name\n",
        "      params[c.name] = c.params\n",
        "    return flax.core.freeze(params)\n",
        "\n",
        "  def get_trainable_params(self):\n",
        "    params = {}\n",
        "    for c in self.components:\n",
        "      if c.is_trainable():\n",
        "        assert c.name not in params, c.name\n",
        "        params[c.name] = c.params\n",
        "    return flax.core.freeze(params)\n",
        "\n",
        "  def get_fixed_params(self):\n",
        "    params = {}\n",
        "    for c in self.components:\n",
        "      if not c.is_trainable():\n",
        "        assert c.name not in params, c.name\n",
        "        params[c.name] = c.params\n",
        "    return flax.core.freeze(params)\n",
        "\n",
        "  def update_trainable(self, trained_params):\n",
        "    trainable_count = 0\n",
        "    for c in self.components:\n",
        "      if c.is_trainable():\n",
        "        trainable_count += 1\n",
        "        assert c.name in trained_params.keys()\n",
        "        c.params = trained_params[c.name]\n",
        "    assert len(trained_params.keys()) == trainable_count, (\n",
        "        f\"{len(trained_params.keys())} {trainable_count}\")\n",
        "\n",
        "  def get_num_accounted_params(self):\n",
        "    rtn = 0\n",
        "    for c in self.components:\n",
        "      tl = copy.copy(c.train_locks)\n",
        "      assert type(tl) is set\n",
        "      tl.add(self.agent_id)\n",
        "      assert tl\n",
        "      rtn += c.get_num_params() / len(tl)\n",
        "    return rtn\n",
        "\n",
        "  def get_flops(self):\n",
        "    return compute_flops_hlo(\n",
        "          partial(self.model.apply, train=False),\n",
        "          {\"params\": self.model_factory.get_comps2model_fn()(merge_params(\n",
        "              self.get_trainable_params(),\n",
        "              self.get_fixed_params()))},\n",
        "          self.model_factory.get_sample_input(self.hparams))\n",
        "\n",
        "  def get_optimizer(self):\n",
        "    return Path.cached_optimizers(\n",
        "        num_train_batches_between_validations=\n",
        "            self.task.num_train_batches_between_validations,\n",
        "        num_validations_per_path_training=\n",
        "            self.task.config.num_validations_per_path_training,\n",
        "        **self.hparams)"
      ]
    },
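    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick sketch of the parameter-accounting rule used by `Path.get_num_accounted_params`: each component's parameter count is split evenly among the agents train-locking it, plus the current agent. The snippet below is illustrative only; the `(num_params, train_locks)` pairs are made-up stand-ins for `Component` objects:\n",
        "\n",
        "```python\n",
        "# Hypothetical components as (num_params, train_locks) pairs.\n",
        "comps = [(1000, {\"agent_b\"}), (500, set())]\n",
        "agent_id = \"agent_a\"\n",
        "accounted = sum(n / len(locks | {agent_id}) for n, locks in comps)\n",
        "# 1000 params shared by 2 agents + 500 owned solely: 500 + 500 = 1000.\n",
        "assert accounted == 1000\n",
        "```"
      ]
    },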
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bkRmcJgzUbwN"
      },
      "outputs": [],
      "source": [
        "class ScorerDecay():\n",
        "  def __init__(self, scale_factor=1, base_accounted_params=0, base_flops=0):\n",
        "    assert 0.0 \u003c scale_factor \u003c= 1.0\n",
        "    self.scale_factor = scale_factor\n",
        "    self.base_accounted_params = base_accounted_params\n",
        "    self.base_flops = base_flops\n",
        "\n",
        "  def score(self, path):\n",
        "    if (\"quality\" not in path.metrics\n",
        "        or math.isnan(path.metrics[\"quality\"])):\n",
        "      return None\n",
        "    assert path.metrics[\"quality\"] \u003e= 0, (\n",
        "        f\"{path.task_name} {path.metrics['quality']}\")\n",
        "    score = path.metrics[\"quality\"]\n",
        "    if self.base_accounted_params \u003e 0:\n",
        "      # Accounted params needs to be updated since it depends on the\n",
        "      # changing structure of the system.\n",
        "      path.metrics[\"accounted_params\"] = path.get_num_accounted_params()\n",
        "      score *= self.scale_factor ** (\n",
        "          path.metrics[\"accounted_params\"] / self.base_accounted_params)\n",
        "    if self.base_flops \u003e 0:\n",
        "      if \"flops\" not in path.metrics:\n",
        "        path.metrics[\"flops\"] = path.get_flops()\n",
        "      score *= self.scale_factor ** (path.metrics[\"flops\"] / self.base_flops)\n",
        "    assert score \u003e= 0\n",
        "    path.metrics[\"score\"] = score\n",
        "    return score"
      ]
    },
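    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Worked example of the decay score (with made-up numbers, not values from the paper): `ScorerDecay` multiplies quality by `scale_factor` raised to the ratio of the path's cost over the configured base cost, so doubling the accounted params with `scale_factor=0.5` and `base_accounted_params=1e6` quarters the multiplier:\n",
        "\n",
        "```python\n",
        "scale_factor, base_params = 0.5, 1e6\n",
        "quality, accounted_params = 0.8, 2e6\n",
        "score = quality * scale_factor ** (accounted_params / base_params)\n",
        "# 0.8 * 0.5**2 = 0.2\n",
        "assert abs(score - 0.2) \u003c 1e-12\n",
        "```"
      ]
    },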
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YMbYgKd8_nyi"
      },
      "outputs": [],
      "source": [
        "class Population():\n",
        "  def __init__(self, config):\n",
        "    Path.reset_globals(config)\n",
        "    Component.reset_globals()\n",
        "    self.paths = defaultdict(list)\n",
        "    self.config = config\n",
        "    self.paths_df = pd.DataFrame()\n",
        "    self.comps_df = pd.DataFrame()\n",
        "\n",
        "  def get_best_path(self):\n",
        "    if len(self.paths[self.config.agent_id]) == 0:\n",
        "      return None\n",
        "    # Oldest path achieving max score.\n",
        "    return max(sorted(self.paths[self.config.agent_id], key=lambda p: p.id, reverse=False), key=lambda p: p.score())\n",
        "\n",
        "  def prune_population(self):\n",
        "    if self.config.get(\"max_task_population_size\", None) and (\n",
        "        len(self.paths[self.config.agent_id]) \u003e self.config.max_task_population_size):\n",
        "      self.paths[self.config.agent_id] = sorted(\n",
        "          self.paths[self.config.agent_id], key=lambda p: p.score(), reverse=True\n",
        "          )[:self.config.max_task_population_size]\n",
        "\n",
        "  def add_train_locks(self):\n",
        "    # Check.\n",
        "    for ps in self.paths.values():\n",
        "      for p in ps:\n",
        "        for c in p.components:\n",
        "          assert self.config.agent_id not in c.train_locks\n",
        "    # Add locks.\n",
        "    paths = self.paths[self.config.agent_id]\n",
        "    for p in paths:\n",
        "      for c in p.components:\n",
        "        c.train_locks.add(self.config.agent_id)\n",
        "\n",
        "  def rm_train_locks(self):\n",
        "    # Remove locks.\n",
        "    paths = self.paths[self.config.agent_id]\n",
        "    for p in paths:\n",
        "      for c in p.components:\n",
        "        if self.config.agent_id in c.train_locks:\n",
        "          c.train_locks.remove(self.config.agent_id)\n",
        "    # Check.\n",
        "    for ps in self.paths.values():\n",
        "      for p in ps:\n",
        "        for c in p.components:\n",
        "          assert self.config.agent_id not in c.train_locks\n",
        "\n",
        "  def start_cycle(self):\n",
        "    self.rm_train_locks()\n",
        "\n",
        "  def end_cycle(self):\n",
        "    # Keep only best one.\n",
        "    best_path = self.get_best_path()\n",
        "    assert best_path is not None\n",
        "    best_path.metrics[\"num_cycles\"] = best_path.metrics.get(\"num_cycles\", 0) + 1\n",
        "    self.paths[self.config.agent_id] = [best_path]\n",
        "    self.add_train_locks()\n",
        "    self.garbage_collect_paths()\n",
        "\n",
        "  def garbage_collect_paths(self):\n",
        "    # Store history before dropping references to unused paths to trigger\n",
        "    # garbage collection of components and parameters.\n",
        "    self.paths_df = self.paths_df.append(\n",
        "        paths_to_df(Path.paths), ignore_index=True\n",
        "        ).query(f'agent_id==\"{self.config.agent_id}\" and id\u003e{Path.last_saved}'\n",
        "        # Drop duplicates generated by reloads, notice that some state-based\n",
        "        # metrics may vary (e.g. accounted parameters) so we match only id\n",
        "        # (agent_id is already matched from the preceding query).\n",
        "        ).drop_duplicates(\"id\")\n",
        "    self.comps_df = self.comps_df.append(\n",
        "        components_to_df(Path.paths), ignore_index=True\n",
        "        ).query(f'agent_id==\"{self.config.agent_id}\" and id\u003e{Component.last_saved}'\n",
        "        ).drop_duplicates()\n",
        "    # Drop unused paths generated in this agent cycle for garbage collection.\n",
        "    Path.paths = []\n",
        "    # Simplify ancestor tree to contain only live paths.\n",
        "    # Notice that the simplification is done also for paths of other tasks,\n",
        "    # since they may be pointing to a path of this task that was discarded.\n",
        "    live_paths_ids = [p.full_id for paths in self.paths.values() for p in paths]\n",
        "    for path in [path for paths in self.paths.values() for path in paths]:\n",
        "      ancestor = path.parent\n",
        "      if ancestor is None:\n",
        "        continue\n",
        "      while True:\n",
        "        if ancestor.full_id in live_paths_ids:\n",
        "          path.parent = ancestor\n",
        "          break\n",
        "        ancestor = ancestor.parent\n",
        "\n",
        "  def get_path_from_full_id(self, agent_id, path_id):\n",
        "    for p in self.paths[agent_id]:\n",
        "      if p.id == path_id:\n",
        "        return p\n",
        "    assert False, f\"Path not found {agent_id}:{path_id}\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4EgQHbawpNcS"
      },
      "outputs": [],
      "source": [
        "pd.set_option(\"display.expand_frame_repr\", False)\n",
        "pd.set_option(\"display.max_columns\", 100)\n",
        "pd.set_option(\"display.max_rows\", 100)\n",
        "\n",
        "def pop_to_df(pop):\n",
        "  return paths_to_df([p for paths in pop.paths.values() for p in paths])\n",
        "\n",
        "def paths_to_df(paths):\n",
        "  # Collect all metrics names.\n",
        "  metrics_keys = set()\n",
        "  hparams_keys = set()\n",
        "  for path in paths:\n",
        "    path.score()  # Update scores.\n",
        "    metrics_keys.update(path.metrics)\n",
        "    hparams_keys.update(path.hparams)\n",
        "\n",
        "  def _format(x):\n",
        "    if type(x) in [dict, list]:\n",
        "      return json.dumps(x)\n",
        "    return x\n",
        "\n",
        "  data = defaultdict(list)\n",
        "  for path in paths:\n",
        "    data[\"agent_id\"].append(path.agent_id)\n",
        "    data[\"task_name\"].append(path.task_name)\n",
        "    data[\"id\"].append(path.id)\n",
        "    data[\"parent_id\"].append(path.parent.id if path.parent else -1)\n",
        "    data[\"parent_agent_id\"].append(path.parent.agent_id if path.parent else None)\n",
        "    data[\"components\"].append(\",\".join(\n",
        "        [f\"{c.agent_id}:{c.id}\" for c in path.comps_only()]))\n",
        "    for k in hparams_keys:\n",
        "      data[f\"hparams.{k}\"].append(_format(path.hparams[k]) if k in path.hparams else None)\n",
        "    for k in metrics_keys:\n",
        "      data[f\"metrics.{k}\"].append(path.metrics[k] if k in path.metrics else None)\n",
        "  return pd.DataFrame(data)\n",
        "\n",
        "def components_to_df(paths):\n",
        "  # Collect all components.\n",
        "  comps = set()\n",
        "  for p in paths:\n",
        "    comps.update(p.comps_only())\n",
        "\n",
        "  data = defaultdict(list)\n",
        "  for c in comps:\n",
        "    data[\"id\"].append(c.id)\n",
        "    data[\"name\"].append(c.name)\n",
        "    data[\"agent_id\"].append(c.agent_id)\n",
        "    data[\"num_params\"].append(c.get_num_params())\n",
        "  return pd.DataFrame(data)\n",
        "\n",
        "def print_df_segments(df, segment_length:int = 5):\n",
        "  tot_length = df.shape[0]\n",
        "  # Pad column title with spaces to keep alignment across segments.\n",
        "  def prepend_spaces(original_str, pad_to_len):\n",
        "    return \" \" * (pad_to_len-len(original_str)) + original_str\n",
        "  pad_to_len = max([len(tn) for tn in set(df[\"agent_id\"].to_list())])+1\n",
        "  df = df.rename(columns={\n",
        "    \"agent_id\": prepend_spaces(\"agent_id\", pad_to_len),\n",
        "    \"task_name\": prepend_spaces(\"task_name\", pad_to_len),\n",
        "    \"parent_agent_id\": prepend_spaces(\"parent_agent_id\", pad_to_len),\n",
        "    })\n",
        "  for x in range(0, tot_length, segment_length):\n",
        "    print(df[x:min(x+segment_length, tot_length)])\n",
        "\n",
        "def df_leaderboard(df):\n",
        "  # Place columns on the left for readability.\n",
        "  all_keys = sorted(df.columns.tolist())\n",
        "  first_keys = [\"agent_id\", \"task_name\", \"metrics.test_quality\", \"metrics.score\",\n",
        "                \"metrics.quality\", \"metrics.accounted_params\", \"metrics.flops\",\n",
        "                \"id\", \"parent_id\", \"parent_agent_id\"]\n",
        "  first_keys = [k for k in first_keys if k in all_keys]\n",
        "  sorted_keys = first_keys + [k for k in all_keys if k not in first_keys]\n",
        "  # Filter mu function parameters.\n",
        "  sorted_keys = [k for k in sorted_keys if \"_mu_|\" not in k]\n",
        "  df = df[sorted_keys]\n",
        "  if \"metrics.score\" in df:\n",
        "    df = df.sort_values([\"agent_id\", \"metrics.score\"], ascending=[True, False], ignore_index=True)\n",
        "  else:\n",
        "    df = df.sort_values(\"agent_id\", ignore_index=True)\n",
        "  print_df_segments(df)\n",
        "  for k in [\"metrics.score\", \"metrics.quality\", \"metrics.test_quality\"]:\n",
        "    if k in df:\n",
        "      print(f\"Avg {k}: {df[k].mean():.6f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rDU0butVlu2f"
      },
      "source": [
        "# Checkpointing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "104ReuGLspoJ"
      },
      "outputs": [],
      "source": [
        "def df_write_to_csv(df, dir_path, df_name):\n",
        "  filename_df = os.path.join(dir_path, f\"{df_name}.csv\")\n",
        "  with gfile.GFile(filename_df, \"w\") as outfile:\n",
        "    df.to_csv(outfile, index=False)\n",
        "\n",
        "def df_read_from_csv(dir_path, df_name):\n",
        "  filename_df = os.path.join(dir_path, f\"{df_name}.csv\")\n",
        "  with gfile.GFile(filename_df, \"r\") as infile:\n",
        "    df = pd.read_csv(infile)\n",
        "  # Pandas read_csv() reads empty stings as NaNs. Set NaNs to empty strings in\n",
        "  # columns with type strings/object.\n",
        "  for c in df.columns:\n",
        "    if df[c].dtype == np.object_:\n",
        "        df[c].fillna(\"\", inplace=True)\n",
        "  return df\n",
        "\n",
        "def get_comps_params_to_save(pop):\n",
        "  comps_params = {}\n",
        "  # All components generated by this agent.\n",
        "  all_comps = set(\n",
        "      [c for p in pop.paths[pop.config.agent_id] for c in p.comps_only() if c.agent_id == pop.config.agent_id])\n",
        "  # Check that there are not duplicate ids.\n",
        "  assert len(all_comps) == len(set([c.id for c in all_comps])), (\n",
        "      [f\"{c.name}:{c.agent_id}:{c.id}\" for c in all_comps])\n",
        "  for c in all_comps:\n",
        "    if c.id \u003c= Component.last_saved:\n",
        "      continue\n",
        "    assert c.agent_id == pop.config.agent_id\n",
        "    c_id_string = f\"{c.name}:{c.agent_id}:{c.id}\"\n",
        "    comps_params[c_id_string] = c.params\n",
        "  return comps_params"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z_AmkV7hLeQp"
      },
      "outputs": [],
      "source": [
        "def latest_checkpoint(ckpt_dir, prefix = \"checkpoint_\"):\n",
        "  ckpt_dir = os.fspath(ckpt_dir)\n",
        "  glob_path = os.path.join(ckpt_dir, f\"{prefix}*\")\n",
        "  checkpoint_files = flax_checkpoints.natural_sort(gfile.glob(glob_path))\n",
        "  checkpoint_files = [f for f in checkpoint_files if not f.endswith(\"_tmp\")]\n",
        "  return checkpoint_files[-1] if checkpoint_files else None"
      ]
    },
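    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`latest_checkpoint` relies on flax's natural sort so that, e.g., `checkpoint_10_0` ranks after `checkpoint_9_0`, which plain lexicographic sorting gets wrong. A minimal re-implementation of the same idea (hypothetical filenames, illustrative key function):\n",
        "\n",
        "```python\n",
        "import re\n",
        "\n",
        "def natural_key(s):\n",
        "  # Split out digit runs and compare them numerically.\n",
        "  return [int(t) if t.isdigit() else t for t in re.split(r\"(\\\\d+)\", s)]\n",
        "\n",
        "files = [\"checkpoint_9_0\", \"checkpoint_10_0\"]\n",
        "assert sorted(files)[-1] == \"checkpoint_9_0\"  # Lexicographic: wrong.\n",
        "assert sorted(files, key=natural_key)[-1] == \"checkpoint_10_0\"\n",
        "```"
      ]
    },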
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M_RTFjoOwfpM"
      },
      "outputs": [],
      "source": [
        "def save_checkpoint(ckpt_dir, comps_params, cycle_id, generation_id):\n",
        "  print(\"SAVING\", cycle_id, generation_id, comps_params.keys())\n",
        "  # Write checkpoint.\n",
        "  flax_checkpoints.save_checkpoint(\n",
        "      ckpt_dir=ckpt_dir,\n",
        "      target=comps_params,\n",
        "      step=generation_id,\n",
        "      prefix=f\"checkpoint_{cycle_id}_\",\n",
        "      overwrite=True)\n",
        "  # Delete intermediate checkpoint directories.\n",
        "  if generation_id == 0:\n",
        "    intermediate_ckpt_dirs = gfile.glob(\n",
        "        os.path.join(os.path.dirname(ckpt_dir), \"state_*_[^0]*\"))\n",
        "    for d in intermediate_ckpt_dirs:\n",
        "      print(\"Deleting intermediate checkpoint:\", d)\n",
        "      gfile.rmtree(d)\n",
        "\n",
        "def save_state(agent):\n",
        "  pop = agent.pop\n",
        "  cycle_id = agent.cycle_id\n",
        "  generation_id = agent.generation_id\n",
        "  config = agent.config\n",
        "  write_start = time.time()\n",
        "  # Save data needed to resume exp.\n",
        "  pop.garbage_collect_paths()\n",
        "  state_dir = os.path.join(config.agent_dir, f\"state_{cycle_id}_{generation_id}\")\n",
        "  gfile.makedirs(state_dir)\n",
        "  assert not latest_checkpoint(state_dir), f\"Checkpoint already present in forlder: {state_dir}\"\n",
        "  print(\"WRITING CHECKPOINT:\", cycle_id, generation_id)\n",
        "  df_write_to_csv(paths_to_df(agent.get_paths_to_publish()), state_dir, \"published\")\n",
        "  df_write_to_csv(paths_to_df([p for paths in pop.paths.values() for p in paths]), state_dir, \"population\")\n",
        "  df_write_to_csv(pop.paths_df, state_dir, \"paths\")\n",
        "  df_write_to_csv(pop.comps_df, state_dir, \"components\")\n",
        "  json.dump(config.as_configdict().to_dict(), gfile.GFile(os.path.join(state_dir, \"config.json\"), \"w\"), indent=2)\n",
        "  save_checkpoint(state_dir, get_comps_params_to_save(pop), cycle_id, generation_id)\n",
        "  # Update last saved.\n",
        "  if generation_id == 0:\n",
        "    Path.last_saved = pop.paths_df.id.max()\n",
        "    Component.last_saved = pop.comps_df.id.max()\n",
        "  print(f\"STATE WRITE TIME: {time.time() - write_start:.2f} s\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7yWdy7DskBph"
      },
      "outputs": [],
      "source": [
        "def load_paths(pop, state_dir, all_agents_dirs):\n",
        "  if state_dir:\n",
        "    state_dir = state_dir.rstrip(\"/\")\n",
        "  load_start = time.time()\n",
        "\n",
        "  # Load system state info.\n",
        "  population_df = pd.DataFrame()\n",
        "  skip_agent_dir = None\n",
        "  if state_dir:\n",
        "    # Load agent state, possibly intermediate.\n",
        "    population_df = population_df.append(df_read_from_csv(state_dir, \"published\"))\n",
        "    skip_agent_dir = os.path.dirname(state_dir)\n",
        "  for agent_dir in all_agents_dirs:\n",
        "    if agent_dir == skip_agent_dir:\n",
        "      continue\n",
        "    agent_checkpoint = latest_checkpoint(os.path.join(agent_dir, \"state_*_0/\"))\n",
        "    if agent_checkpoint:\n",
        "      population_df = population_df.append(\n",
        "          df_read_from_csv(os.path.dirname(agent_checkpoint), \"published\"))\n",
        "\n",
        "  # Load parameters from sharded system checkpoint.\n",
        "  loaded_params = {}  # Dictionary to accumlate loaded parameters.\n",
        "  lock = Lock()\n",
        "  duplicate_keys = set()\n",
        "  def append_loaded_params(add_chkp_dir: str):\n",
        "    if latest_checkpoint(add_chkp_dir) is None:\n",
        "      return  # Skip folders without a completed checkpoint.\n",
        "    lp_add = flax_checkpoints.restore_checkpoint(\n",
        "        ckpt_dir=add_chkp_dir,\n",
        "        target=None)\n",
        "    if lp_add:\n",
        "      lock.acquire()\n",
        "      print(\"LOADED COMPONENTS\", add_chkp_dir, lp_add.keys())\n",
        "      duplicate_keys.update(loaded_params.keys() \u0026 lp_add.keys())\n",
        "      loaded_params.update(lp_add)\n",
        "      lock.release()\n",
        "  all_state_dirs = []\n",
        "  if state_dir:\n",
        "    # Include active agent state, possibly intermediate.\n",
        "    all_state_dirs.append(state_dir)\n",
        "    all_state_dirs.extend(gfile.glob(os.path.join(os.path.dirname(state_dir), \"state_*_0\")))\n",
        "  for agent_dir in all_agents_dirs:\n",
        "    all_state_dirs.extend(gfile.glob(os.path.join(agent_dir, \"state_*_0\")))\n",
        "  threads = []\n",
        "  for s_dir in set(all_state_dirs):\n",
        "    threads.append(Thread(target=append_loaded_params, args=(s_dir,)))\n",
        "    threads[-1].start()\n",
        "  for t in threads:\n",
        "    t.join()\n",
        "  assert not duplicate_keys, duplicate_keys\n",
        "  print(f\"LOAD TIME: {time.time() - load_start:.2f} s\")\n",
        "  frozen_params = flax.core.freeze(loaded_params)\n",
        "  sid_2_comp = {}\n",
        "  for k in frozen_params.keys():\n",
        "    assert len(k.split(\":\")) == 3, k\n",
        "    name, agent_id, id = k.split(\":\")\n",
        "    c = Component(\n",
        "        name=name, agent_id=agent_id, params=frozen_params[k], train_locks=[])\n",
        "    c.id = int(id)\n",
        "    source_id = f\"{agent_id}:{id}\"\n",
        "    assert source_id not in sid_2_comp, source_id\n",
        "    sid_2_comp[source_id] = c\n",
        "  # For parent assignemt.\n",
        "  sid_2_path = {}\n",
        "  path_2_parent_sid = {}\n",
        "  for index, row in population_df.iterrows():\n",
        "    agent_id = row[\"agent_id\"]\n",
        "    path_id = int(row[\"id\"])\n",
        "    path_sid = f\"{agent_id}:{path_id}\"\n",
        "    if path_sid in sid_2_path:\n",
        "      continue\n",
        "    comps_sids = row[\"components\"].split(\",\")\n",
        "    comps = []\n",
        "    for sid in comps_sids:\n",
        "      comps.append(sid_2_comp[sid])\n",
        "    task_name = row[\"task_name\"]\n",
        "    # Retrieve hparams and metrics.\n",
        "    hparams = {}\n",
        "    metrics = {}\n",
        "    for k in row.keys():\n",
        "      v = row[k]\n",
        "      if type(v) is float and math.isnan(v):\n",
        "        continue\n",
        "      if k.startswith(\"hparams.\"):\n",
        "        if type(v) == str and (v.startswith(\"{\") or v.startswith(\"[\")):\n",
        "          v = json.loads(v)\n",
        "        hparams[k[len(\"hparams.\"):]] = v\n",
        "      if k.startswith(\"metrics.\"):\n",
        "        metrics[k[len(\"metrics.\"):]] = v\n",
        "    # Create path.\n",
        "    path = Path(hparams, comps, parent=None, agent_id=agent_id, task_name=task_name)\n",
        "    path.metrics = metrics\n",
        "    path.id = path_id\n",
        "    # Add train locks.\n",
        "    for c in path.components:\n",
        "      c.train_locks.add(agent_id)\n",
        "    pop.paths[agent_id].append(path)\n",
        "    sid_2_path[path_sid] = path\n",
        "    if row[\"parent_id\"] \u003e= 0:\n",
        "      parent_sid = f'{row[\"parent_agent_id\"]}:{row[\"parent_id\"]}'\n",
        "      path_2_parent_sid[path] = parent_sid\n",
        "  # Set parents.\n",
        "  for path, parent_sid in path_2_parent_sid.items():\n",
        "    if parent_sid not in sid_2_path:\n",
        "      # This can happen if parent is retired by a parallel agent.\n",
        "      # In this case fall back to root model.\n",
        "      for k in sid_2_path.keys():\n",
        "        if \"root_model\" in k:\n",
        "          parent_sid = k\n",
        "      print(f\"{path.agent_id}:{path.id} orphaned, fallback: {parent_sid}\")\n",
        "    path.parent = sid_2_path[parent_sid]\n",
        "  # Set reference to components representing sub paths.\n",
        "  for path in [p for paths in pop.paths.values() for p in paths]:\n",
        "    if \"paths\" in path.hparams:\n",
        "      for k in path.hparams[\"paths\"]:\n",
        "        sub_path = pop.get_path_from_full_id(path.hparams[\"paths\"][k][\"agent_id\"], path.hparams[\"paths\"][k][\"id\"])\n",
        "        path.components.append(ComponentPath(name=k, path=sub_path))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_0AyzNaTl002"
      },
      "source": [
        "# Training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m6vSjSIvNPq4"
      },
      "outputs": [],
      "source": [
        "@partial(jax.jit, static_argnames=\"model\")\n",
        "def eval_step(params, inputs, labels, model):\n",
        "  logits = model.apply({\"params\": params}, inputs, train=False)\n",
        "  # Avg accuracy on the batch.\n",
        "  return (logits.argmax(axis=-1) == labels).mean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_m2xl8XR7cWy"
      },
      "outputs": [],
      "source": [
        "@partial(jax.jit, static_argnames=[\"model\", \"optimizer\", \"format_params_fn\"], donate_argnums=[0, 2])\n",
        "def train_step(params, fixed_params, opt_state, inputs, labels, model, optimizer, format_params_fn):\n",
        "  def loss_fn(params, fixed_params, inputs, labels):\n",
        "    logits = model.apply(\n",
        "        {\"params\": format_params_fn(merge_params(params, fixed_params))},\n",
        "        inputs, train=True)\n",
        "    labels = jax.nn.one_hot(labels, logits.shape[-1])\n",
        "    return -jnp.mean(jnp.sum(labels * nn.log_softmax(logits), axis=-1))\n",
        "  grads = jax.grad(loss_fn)(params, fixed_params, inputs, labels)\n",
        "  updates, opt_state = optimizer.update(grads, opt_state, params=params)\n",
        "  params = optax.apply_updates(params, updates)\n",
        "  return params, opt_state"
      ]
    },
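    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loss in `train_step` is standard softmax cross-entropy against one-hot labels. A small NumPy check of the identity it relies on (illustrative only, with arbitrary logits):\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "logits = np.array([2.0, 0.0, -1.0])\n",
        "label = 0\n",
        "log_softmax = logits - np.log(np.sum(np.exp(logits)))\n",
        "# -sum(one_hot * log_softmax) picks out -log_softmax[label].\n",
        "one_hot = np.eye(3)[label]\n",
        "assert abs(-np.sum(one_hot * log_softmax) - (-log_softmax[label])) \u003c 1e-12\n",
        "```"
      ]
    },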
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RrvyCDaw0KFZ"
      },
      "outputs": [],
      "source": [
        "def execute_train_step(path, train_batch):\n",
        "  path.params_device, path.opt_state_device = train_step(\n",
        "      path.params_device,\n",
        "      path.fixed_params_device,\n",
        "      path.opt_state_device,\n",
        "      train_batch[\"image\"],\n",
        "      train_batch[\"label\"],\n",
        "      path.model,\n",
        "      path.optimizer,\n",
        "      path.model_factory.get_comps2model_fn())\n",
        "\n",
        "def execute_eval_step(path, eval_batch):\n",
        "  path.accs.append(\n",
        "      eval_step(\n",
        "          path.model_factory.get_comps2model_fn()(merge_params(\n",
        "              path.params_device, path.fixed_params_device)),\n",
        "          eval_batch[\"image\"],\n",
        "          eval_batch[\"label\"],\n",
        "          path.model))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lRjic_IpGYJU"
      },
      "outputs": [],
      "source": [
        "PREV_LOOP_END = time.time()\n",
        "\n",
        "def train_loop(paths, ds_train, ds_validation, devices, config):\n",
        "  global PREV_LOOP_END\n",
        "  timing = {}\n",
        "  task = paths[0].task\n",
        "  for path in paths:\n",
        "    assert task.name == path.task_name\n",
        "  for p_id, path in enumerate(paths):\n",
        "    path.device_id = p_id % len(devices)\n",
        "    path.device = devices[path.device_id]\n",
        "    path.optimizer = path.get_optimizer()\n",
        "    path.best_params_local = None\n",
        "    path.best_quality = None\n",
        "    path.best_score = path.parent.score() if path.agent_id == path.parent.agent_id else -np.inf\n",
        "    path.evals = []\n",
        "    path.exe_thread = None\n",
        "  gc.collect()\n",
        "  # Tranfer parameters to devices.\n",
        "  for path in paths:\n",
        "    path.params_device = jax.device_put(path.get_trainable_params(), path.device)\n",
        "    path.fixed_params_device = jax.device_put(path.get_fixed_params(), path.device)\n",
        "    path.opt_state_device = jax.jit(path.optimizer.init, device=path.device)(path.params_device)\n",
        "  iter_ds_validation = iter(ds_validation)\n",
        "  # Train loop.\n",
        "  for t_step, train_batch in zip(\n",
        "      range(config.num_validations_per_path_training\n",
        "            * task.num_train_batches_between_validations),\n",
        "      ds_train):\n",
        "    if t_step == 0:\n",
        "      timing[\"start_train\"] = time.time()\n",
        "    for p_id, path in enumerate(paths):\n",
        "      train_batch = jax.device_put(train_batch, path.device)\n",
        "      if path.exe_thread is not None:\n",
        "        path.exe_thread.join()\n",
        "      path.exe_thread = Thread(target=execute_train_step, args=(path, train_batch))\n",
        "      path.exe_thread.start()\n",
        "    if t_step == 0:\n",
        "      [p.exe_thread.join() for p in paths]\n",
        "      timing[\"end_train_compile\"] = time.time()\n",
        "    # Evaluation on validation set.\n",
        "    if (t_step+1) % task.num_train_batches_between_validations == 0:\n",
        "      for path in paths:\n",
        "        path.accs = []\n",
        "      for e_step, eval_batch in zip(range(task.num_validation_batches), iter_ds_validation):\n",
        "        if e_step == 0:\n",
        "          start_eval_round = time.time()\n",
        "          if \"start_eval\" not in timing:\n",
        "            timing[\"start_eval\"] = start_eval_round\n",
        "        for p_id, path in enumerate(paths):\n",
        "          eval_batch = jax.device_put(eval_batch, path.device)\n",
        "          path.exe_thread.join()\n",
        "          path.exe_thread = Thread(target=execute_eval_step, args=(path,eval_batch))\n",
        "          path.exe_thread.start()\n",
        "        if e_step == 0 and \"end_eval_compile\" not in timing:\n",
        "          [p.exe_thread.join() for p in paths]\n",
        "          timing[\"end_eval_compile\"] = time.time()\n",
        "\n",
        "      # Get params of best models.\n",
        "      qs = []\n",
        "      eval_idx = (t_step+1) // task.num_train_batches_between_validations\n",
        "      for path in paths:\n",
        "        path.exe_thread.join()\n",
        "        quality = np.mean(path.accs)\n",
        "        del path.accs\n",
        "        qs.append(f\"{quality:.4f}\")\n",
        "        path.evals.append(quality)\n",
        "        # Set quality in metrics for current score computation.\n",
        "        path.metrics[\"quality\"] = quality\n",
        "        path_score = path.score()\n",
        "        if path_score \u003e path.best_score:\n",
        "          path.best_params_local = jax.device_get(path.params_device)\n",
        "          path.best_score = path_score\n",
        "          path.best_quality = quality\n",
        "          qs[-1] += \"*\"\n",
        "      time_train = time.time() - PREV_LOOP_END\n",
        "      avg_path_time = (time_train / eval_idx) / len(paths)\n",
        "      print((\"\\t\".join(qs) + f\"\\t\u003c Eval {eval_idx}\").expandtabs(8),\n",
        "            f\"tot:{time_train:.1f}s\", f\"avg/path:{avg_path_time:.1f}s\")\n",
        "      timing[\"time_eval\"] = timing.get(\"time_eval\", 0) + (time.time() - start_eval_round)\n",
        "      del eval_batch\n",
        "  del train_batch\n",
        "  for path in paths:\n",
        "    del path.params_device\n",
        "    del path.fixed_params_device\n",
        "    del path.opt_state_device\n",
        "    del path.optimizer\n",
        "    del path.exe_thread\n",
        "  gc.collect()\n",
        "\n",
        "  timing[\"end_train\"] = time.time()\n",
        "\n",
        "  time_init = timing[\"start_train\"] - PREV_LOOP_END\n",
        "  time_train_compile = timing[\"end_train_compile\"] - timing[\"start_train\"]\n",
        "  time_eval_compile = timing[\"end_eval_compile\"] - timing[\"start_eval\"]\n",
        "  time_eval = timing[\"time_eval\"] - time_eval_compile\n",
        "  time_train = timing[\"end_train\"] - timing[\"end_train_compile\"] - time_eval - time_eval_compile\n",
        "  PREV_LOOP_END = timing[\"end_train\"]\n",
        "\n",
        "  for path in paths:\n",
        "    path.metrics[\"time_init\"] = time_init\n",
        "    path.metrics[\"time_train_compile\"] = time_train_compile\n",
        "    path.metrics[\"time_eval_compile\"] = time_eval_compile\n",
        "    path.metrics[\"time_train\"] = time_train\n",
        "    path.metrics[\"time_eval\"] = time_eval\n",
        "    path.metrics[\"timestamp_end\"] = PREV_LOOP_END\n",
        "    path.metrics[\"num_params\"] = get_num_params(path.get_all_params())\n",
        "    path.metrics[\"num_trainable_params\"] = get_num_params(path.get_trainable_params())\n",
        "    path.metrics[\"quality\"] = max(path.evals)\n",
        "    path.metrics[\"evals\"] = json.dumps([float(v) for v in path.evals])\n",
        "\n",
        "    if path.best_params_local is not None:\n",
        "      path.metrics[\"improved\"] = True\n",
        "      path.update_trainable(path.best_params_local)\n",
        "      assert path.best_quality == path.metrics[\"quality\"]\n",
        "      assert path.best_score == path.score()\n",
        "    else:\n",
        "      path.metrics[\"improved\"] = False\n",
        "      # Sampled path will be dropped if not improved, so skip parameter update.\n",
        "      assert path.best_quality is None\n",
        "\n",
        "    del path.best_params_local\n",
        "    del path.best_score\n",
        "    del path.best_quality\n",
        "    del path.evals\n",
        "\n",
        "  pqs = []\n",
        "  qs = []\n",
        "  psc = []\n",
        "  sc = []\n",
        "  for path in paths:\n",
        "    if path.task_name == path.parent.task_name:\n",
        "      metric_suffix = \"\" if path.agent_id == path.parent.agent_id else \"A\"\n",
        "      pqs.append(f\"{path.parent.metrics['quality']:.4f}{metric_suffix}\")\n",
        "      psc.append(f\"{path.parent.score():.4f}{metric_suffix}\")\n",
        "    else:\n",
        "      pqs.append(\"NEW\")\n",
        "      psc.append(\"NEW\")\n",
        "    qs.append(f\"{path.metrics['quality']:.4f}\")\n",
        "    sc.append(f\"{path.score():.4f}\")\n",
        "    if path.metrics[\"improved\"]:\n",
        "      sc[-1] += \"+\"\n",
        "\n",
        "  print((\"\\t\".join([f\"{path.parent.id}\" for path in paths]) + \"\\t\u003c Parent id\").expandtabs(8))\n",
        "  print((\"\\t\".join([f\"{path.id}\" for path in paths]) + \"\\t\u003c Path id\").expandtabs(8))\n",
        "  print((\"\\t\".join(pqs) + \"\\t\u003c Parent best quality\").expandtabs(8))\n",
        "  print((\"\\t\".join(qs) + \"\\t\u003c Path best quality\").expandtabs(8))\n",
        "  print((\"\\t\".join(psc) + \"\\t\u003c Parent score\").expandtabs(8))\n",
        "  print((\"\\t\".join(sc) + \"\\t\u003c Path score\").expandtabs(8))\n",
        "  print(\"time\\tINIT\\tCOMPtrn\\tCOMPevl\\tTRN\\tEVAL\".expandtabs(8))\n",
        "  print(f\"(s)\\t{time_init:.1f}\\t{time_train_compile:.1f}\\t{time_eval_compile:.1f}\\t{time_train:.1f}\\t{time_eval:.1f}\".expandtabs(8))"
      ]
    },
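    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loop above overlaps work across paths by giving each path a worker thread: before issuing a new step for a path, it joins that path's previous thread, then starts a new one. A minimal standalone sketch of this join-then-start pattern (with a hypothetical `work` function, not the notebook's `execute_train_step`):\n",
        "```python\n",
        "from threading import Thread\n",
        "\n",
        "class Worker:\n",
        "  def __init__(self):\n",
        "    self.exe_thread = None\n",
        "    self.results = []\n",
        "\n",
        "def work(w, x):\n",
        "  # Stand-in for a per-path train/eval step.\n",
        "  w.results.append(x * 2)\n",
        "\n",
        "workers = [Worker(), Worker()]\n",
        "for step in range(3):\n",
        "  for w in workers:\n",
        "    # Wait for this worker's previous step before starting the next one.\n",
        "    if w.exe_thread is not None:\n",
        "      w.exe_thread.join()\n",
        "    w.exe_thread = Thread(target=work, args=(w, step))\n",
        "    w.exe_thread.start()\n",
        "for w in workers:\n",
        "  w.exe_thread.join()\n",
        "```"
      ]
    },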
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JVT8nwIWMAVf"
      },
      "outputs": [],
      "source": [
        "def has_test_quality(path):\n",
        "  return (\"test_quality\" in path.metrics and not math.isnan(path.metrics[\"test_quality\"]))\n",
        "\n",
        "# Run final eval on test set.\n",
        "def run_test_eval(path, test_immutability=False):\n",
        "  # Skip if test_quality already computed and no immutability test required.\n",
        "  if not test_immutability and has_test_quality(path):\n",
        "    return\n",
        "  eval_st = time.time()\n",
        "  ds_test = path.task.get_ds(\"test\", path.hparams)\n",
        "  params = path.get_all_params()\n",
        "  # Running on the same device allows reusing the fn compiled for validation.\n",
        "  if not hasattr(path, \"device\"):\n",
        "    path.device = random.choice(jax.local_devices())\n",
        "  params_device = jax.device_put(path.model_factory.get_comps2model_fn()(params), path.device)\n",
        "  acc_sum = []\n",
        "  tot_num_samples = 0\n",
        "  # Warning: if repeat() is called on this dataset, then this loop never ends.\n",
        "  for batch in ds_test:\n",
        "    acc_avg = eval_step(params_device, batch[\"image\"], batch[\"label\"], path.model)\n",
        "    batch_size = batch[\"label\"].shape[0]\n",
        "    # Weight by batch size because the last batch can be smaller, which\n",
        "    # allows exact evaluation over the full test set.\n",
        "    acc_sum.append(acc_avg * batch_size)\n",
        "    tot_num_samples += batch_size\n",
        "  del params_device\n",
        "  acc_avg = np.sum(acc_sum) / tot_num_samples\n",
        "  # Assert test quality equivalence to test immutability.\n",
        "  if has_test_quality(path):\n",
        "    print(f\"Testing immutability of path {path.id} : {path.metrics['test_quality']} ~= {acc_avg}\")\n",
        "    assert test_immutability\n",
        "    if not np.isclose(path.metrics[\"test_quality\"], acc_avg, rtol=IMMUTABILITY_RELATIVE_TOLLERANCE):\n",
        "      print(\"WARNING IMMUTABILITY TEST FAILED, delta:\", acc_avg-path.metrics[\"test_quality\"])\n",
        "    assert np.isclose(path.metrics[\"test_quality\"], acc_avg), \\\n",
        "        f\"{path.task_name} {path.metrics['test_quality']} {acc_avg}\"\n",
        "  path.metrics[\"test_quality\"] = acc_avg\n",
        "  print(f\"TEST QUALITY: {acc_avg}\\nTEST TIME: {time.time()-eval_st:.2f}s\")"
      ]
    },
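    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`run_test_eval` computes exact test accuracy as a sample-weighted mean: each batch's average accuracy is scaled back up by its batch size before dividing by the total sample count, since the last batch can be smaller. A standalone sketch of this weighting with hypothetical per-batch values:\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Per-batch average accuracies and their sizes (last batch is smaller).\n",
        "batch_accs = [0.50, 0.75, 1.00]\n",
        "batch_sizes = [4, 4, 2]\n",
        "\n",
        "# Scale each batch average by its size, then divide by the total count.\n",
        "acc_sum = [a * n for a, n in zip(batch_accs, batch_sizes)]\n",
        "acc_avg = np.sum(acc_sum) / np.sum(batch_sizes)  # (2 + 3 + 2) / 10 = 0.7\n",
        "print(acc_avg)\n",
        "```"
      ]
    },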
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZubWF7rhmHju"
      },
      "source": [
        "# Main"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wzO_0X5-S0yj"
      },
      "outputs": [],
      "source": [
        "agent = get_agent_class(AGENT)(system_state_dir=SYSTEM_STATE_DIR, task_name=TASK_NAME, num_cycles_max=NUM_CYCLES_MAX)\n",
        "agent.run()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "name": "mu4Net: Multipath Multiagent Multitask Mutant Net",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
